mFLAG: Multi-Figurative Language Generation

Overview

This repository provides the code and model for mFLAG (Multi-Figurative Language Generation, Lai and Nissim, COLING 2022), which paraphrases between literal text and five figurative forms: hyperbole, idiom, metaphor, sarcasm, and simile.

Quick Start

# clone the repository
git clone git@github.com:laihuiyuan/mFLAG.git
cd mFLAG

# in Python, run from inside the mFLAG directory, which provides model.py and tokenization_mflag.py
from model import MultiFigurativeGeneration
from tokenization_mflag import MFlagTokenizerFast

tokenizer = MFlagTokenizerFast.from_pretrained('laihuiyuan/mFLAG')
model = MultiFigurativeGeneration.from_pretrained('laihuiyuan/mFLAG')


# An example of hyperbole-to-sarcasm generation.
# A token marking the source figure of speech (<hyperbole>) is prepended to the source sentence;
# the available form tokens are <literal>, <hyperbole>, <idiom>, <sarcasm>, <metaphor>, and <simile>.
inp_ids = tokenizer.encode("<hyperbole> I am not happy that he urged me to finish all the hardest tasks in the world", return_tensors="pt")
# the target figurative form (<sarcasm>)
fig_ids = tokenizer.encode("<sarcasm>", add_special_tokens=False, return_tensors="pt")
# inp_ids[:, 1:] drops the BOS token; the target form token is forced as the first decoded token
outs = model.generate(input_ids=inp_ids[:, 1:], fig_ids=fig_ids, forced_bos_token_id=fig_ids.item(), num_beams=5, max_length=60)
# outs[0, 2:] skips the decoder start token and the forced form token
text = tokenizer.decode(outs[0, 2:].tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False)
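
The same pattern works for any source/target pair. As a convenience, the steps above can be wrapped in a small helper; this is a sketch of ours, not part of the repository (the function name and defaults are our own):

def transfer(text, src_form, tgt_form, num_beams=5, max_length=60):
    # prefix the source with its figure-of-speech token, e.g. "<hyperbole> ..."
    inp_ids = tokenizer.encode(f"<{src_form}> {text}", return_tensors="pt")
    # encode the target form token, e.g. "<sarcasm>"
    fig_ids = tokenizer.encode(f"<{tgt_form}>", add_special_tokens=False, return_tensors="pt")
    outs = model.generate(input_ids=inp_ids[:, 1:], fig_ids=fig_ids,
                          forced_bos_token_id=fig_ids.item(),
                          num_beams=num_beams, max_length=max_length)
    # skip the decoder start token and the forced target-form token
    return tokenizer.decode(outs[0, 2:].tolist(), skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)

text = transfer("I am not happy that he urged me to finish all the hardest tasks in the world",
                "hyperbole", "sarcasm")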

Training

Step 1: Pre-training

python train_pt.py -dataset ParapFG -figs hyperbole idiom metaphor sarcasm simile

Step 2: Fine-tuning

# parallel paraphrase pretraining data
python train_ft.py -dataset ParapFG -figs hyperbole idiom metaphor sarcasm simile

# literal-figurative parallel data
python train_ft.py -dataset MultiFG -figs hyperbole idiom metaphor sarcasm simile

Step 3: Figurative Generation

# Generating idioms from hyperbolic text
python inference.py -src_form hyperbole -tgt_form idiom
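
To produce outputs for every ordered pair of figurative forms, the inference script can be driven in a loop. A sketch, assuming inference.py accepts exactly the flags shown above:

import itertools
import subprocess
import sys

# the five figurative forms used throughout the commands above
forms = ["hyperbole", "idiom", "metaphor", "sarcasm", "simile"]
for src_form, tgt_form in itertools.permutations(forms, 2):
    # equivalent to: python inference.py -src_form <src_form> -tgt_form <tgt_form>
    subprocess.run([sys.executable, "inference.py",
                    "-src_form", src_form, "-tgt_form", tgt_form], check=True)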

Model and Outputs

  • Our model mFLAG is available on Hugging Face; the corresponding outputs are in the /data/outputs/ directory.

Citation

@inproceedings{lai-etal-2022-multi,
    title = "Multi-Figurative Language Generation",
    author = "Lai, Huiyuan and Nissim, Malvina",
    booktitle = "Proceedings of the 29th International Conference on Computational Linguistics",
    month = oct,
    year = "2022",
    address = "Gyeongju, Republic of Korea",
}

Issues

Facing an error in model inference using the provided example

inp_ids = tokenizer.encode("<hyperbole> I am not happy that he urged me to finish all the hardest tasks in the world", return_tensors="pt")
fig_ids = tokenizer.encode("<sarcasm>", add_special_tokens=False, return_tensors="pt")
outs = model.generate(input_ids=inp_ids[:, 1:], fig_ids=fig_ids, forced_bos_token_id=fig_ids.item(), num_beams=5, max_length=60)
text = tokenizer.decode(outs[0, 2:].tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False)


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_32792\2188357501.py in <module>
      1 inp_ids = tokenizer.encode("<hyperbole> I am not happy that he urged me to finish all the hardest tasks in the world", return_tensors="pt")
      2 fig_ids = tokenizer.encode("<sarcasm>", add_special_tokens=False, return_tensors="pt")
----> 3 outs = model.generate(input_ids=inp_ids[:, 1:], fig_ids=fig_ids, forced_bos_token_id=fig_ids.item(), num_beams=5, max_length=60,)
      4 text = tokenizer.decode(outs[0, 2:].tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False)

~\.conda\envs\transformers\lib\site-packages\torch\autograd\grad_mode.py in decorate_context(*args, **kwargs)
     26     def decorate_context(*args, **kwargs):
     27         with self.__class__():
---> 28             return func(*args, **kwargs)
     29     return cast(F, decorate_context)
     30

~\.conda\envs\transformers\lib\site-packages\transformers\generation_utils.py in generate(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, penalty_alpha, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, suppress_tokens, begin_suppress_tokens, forced_decoder_ids, **model_kwargs)
   1338         # and added to model_kwargs
   1339         model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
-> 1340             inputs_tensor, model_kwargs, model_input_name
   1341         )
   1342

~\.conda\envs\transformers\lib\site-packages\transformers\generation_utils.py in _prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor, model_kwargs, model_input_name)
    581         encoder_kwargs["return_dict"] = True
    582         encoder_kwargs[model_input_name] = inputs_tensor
--> 583         model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
    584
    585         return model_kwargs

~\.conda\envs\transformers\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~\Desktop\DS_Team_Tasks\NLP\mFLAG\model.py in forward(self, input_ids, attention_mask, fig_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    228         inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
    229
--> 230         embed_pos = self.embed_positions(input_shape)
    231
    232         hidden_states = inputs_embeds + embed_pos

~\.conda\envs\transformers\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~\.conda\envs\transformers\lib\site-packages\transformers\models\bart\modeling_bart.py in forward(self, input_ids, past_key_values_length)
    132         """`input_ids' shape is expected to be [bsz x seqlen]."""
    133
--> 134         bsz, seq_len = input_ids.shape[:2]
    135         positions = torch.arange(
    136             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device

AttributeError: 'torch.Size' object has no attribute 'shape'
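
A note on the likely cause: the traceback shows model.py passing a torch.Size to the positional-embedding layer (self.embed_positions(input_shape)), while the installed transformers release changed BartLearnedPositionalEmbedding.forward to expect the input_ids tensor itself, so input_ids.shape[:2] fails on a torch.Size. Pinning transformers to the older release the repository was written against should avoid the error; alternatively, a minimal local patch, sketched under the assumption that the encoder is always called with input_ids rather than inputs_embeds:

# model.py, at the line shown in the traceback (a local workaround, not the authors' official fix)
# older transformers API: embed_positions expects a torch.Size
# newer transformers API: embed_positions expects the input_ids tensor itself
embed_pos = self.embed_positions(input_ids)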

About evaluation

root@DESKTOP-Q3VVK87:/home/cschy/mflag_test# perl multi-bleu.perl test_hyperbole.1 < l2hyperbole.txt
BLEU = 51.47, 68.7/56.4/47.8/40.2 (BP=0.985, ratio=0.985, hyp_len=1332, ref_len=1352)

When I use multi-bleu.perl to calculate BLEU, the result is lower than the one reported in your paper (the same goes for idiom, sarcasm, metaphor, and simile). l2hyperbole.txt was created from the output of your checkpoint. Did you do anything extra with the output file? Any help would be appreciated. Below is my generation code:

from model import MultiFigurativeGeneration
from tokenization_mflag import MFlagTokenizerFast

tokenizer = MFlagTokenizerFast.from_pretrained('laihuiyuan/mFLAG', cache_dir='author')
model = MultiFigurativeGeneration.from_pretrained('laihuiyuan/mFLAG', cache_dir='author')

with open('./data/MultiFG/test_hyperbole.0', 'r') as f, open('./data/outputs/l2hyperbole.txt', 'w') as fw:
    for i in f.readlines():
        inp_ids = tokenizer.encode("<literal>" + i.strip(), return_tensors="pt")
        fig_ids = tokenizer.encode("<hyperbole>", add_special_tokens=False, return_tensors="pt")
        outs = model.generate(input_ids=inp_ids[:, 1:], fig_ids=fig_ids,
                              forced_bos_token_id=fig_ids.item(), num_beams=5, max_length=60)
        text = tokenizer.decode(outs[0, 2:].tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print(text)
        fw.write(text + '\n')
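
One frequent cause of BLEU gaps with multi-bleu.perl is a tokenization mismatch between hypotheses and references: the script compares whitespace-separated tokens, so detokenized model output scored against tokenized references loses points. Whether that applies here depends on how the MultiFG references were preprocessed; purely as an illustration, a sketch that tokenizes and lowercases the generated file before scoring (the output file name and the use of NLTK are our assumptions, not part of the repository):

from nltk.tokenize import word_tokenize  # assumes nltk and its 'punkt' data are installed

# hypothetical post-processing: tokenize and lowercase hypotheses so they
# match tokenized references before running multi-bleu.perl
with open('./data/outputs/l2hyperbole.txt', 'r') as f, open('./data/outputs/l2hyperbole.tok.txt', 'w') as fw:
    for line in f:
        fw.write(' '.join(word_tokenize(line.strip().lower())) + '\n')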
