
g-transformer's Introduction

G-Transformer

This code is for the ACL 2021 paper G-Transformer for Document-level Machine Translation.

Python Version: Python 3.6

Package Requirements: torch==1.4.0 tensorboardX numpy==1.19.0
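
These can be installed directly with pip, for example:

    pip install torch==1.4.0 tensorboardX numpy==1.19.0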

Framework: Our model and experiments are built upon fairseq. We use a snapshot version between 0.9.0 and 1.10.0 as our starting code.

Before running the scripts, please install fairseq dependencies by:

    pip install --editable .

Please also follow the READMEs under the folders raw_data and mbart.cc25 to download the raw data and the pretrained model. (Note: our models were trained on 4 GPUs. If you train them on 2 GPUs, in theory you can double the value of the argument --update-freq; however, we haven't tested this setting.)
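
As a rough sketch of that scaling (the actual flag lives inside the scripts under exp_gtrans, and the values below are illustrative assumptions, not the scripts' real defaults):

    # Effective batch size ~= num_GPUs x max_tokens x update_freq, so halving
    # the GPU count from 4 to 2 suggests doubling --update-freq (e.g. 4 -> 8)
    # inside the training script. Illustrative values, untested by the authors.
    CUDA_VISIBLE_DEVICES=0,1 bash exp_gtrans/run-all.sh run-randinit train exp_randinit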

Non-pretraining Settings

G-Transformer random initialized

  • Prepare data:
    mkdir exp_randinit
    bash exp_gtrans/run-all.sh prepare-randinit exp_randinit
  • Train model:
    CUDA_VISIBLE_DEVICES=0,1,2,3 bash exp_gtrans/run-all.sh run-randinit train exp_randinit
  • Evaluate model:
    bash exp_gtrans/run-all.sh run-randinit test exp_randinit

G-Transformer fine-tuned on sent Transformer

  • Prepare data:
    mkdir exp_finetune
    bash exp_gtrans/run-all.sh prepare-finetune exp_finetune
  • Train model:
    CUDA_VISIBLE_DEVICES=0,1,2,3 bash exp_gtrans/run-all.sh run-finetune train exp_finetune
  • Evaluate model:
    bash exp_gtrans/run-all.sh run-finetune test exp_finetune

Pretraining Settings

G-Transformer fine-tuned on mBART25

  • Prepare data:
    mkdir exp_mbart
    bash exp_gtrans/run-all.sh prepare-mbart exp_mbart
  • Train model:
    CUDA_VISIBLE_DEVICES=0,1,2,3 bash exp_gtrans/run-all.sh run-mbart train exp_mbart
  • Evaluate model:
    bash exp_gtrans/run-all.sh run-mbart test exp_mbart

g-transformer's People

Contributors

baoguangsheng

g-transformer's Issues

The BLEU value

May I ask whether the BLEU reported in the G-Transformer paper is tokenized or detokenized BLEU?
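
(For context, a minimal sketch of both scoring variants, assuming standard tooling and plain-text files; all file names are placeholders.)

    # Detokenized BLEU: sacrebleu applies its own tokenization internally.
    sacrebleu reference.detok.txt < hypothesis.detok.txt
    # Tokenized BLEU: Moses multi-bleu.perl scores pre-tokenized text.
    perl mosesdecoder/scripts/generic/multi-bleu.perl reference.tok.txt < hypothesis.tok.txt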

How to expand the input length to 2048 tokens?


I modified "--max-tokens" to 2048 in "prepare-randinit.sh" to prepare all the data. Then I started training, but got an error about "max_positions". Do you know how to fix this problem?

Furthermore, is there anything else I should pay attention to?

Actually, I want to know whether d-BLEU declines when the input length is expanded to 2048 tokens. Have you ever run such an experiment? What is the maximum input length (in tokens) at which d-BLEU remains stable?
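
(A hedged sketch of the usual fairseq knobs involved, assuming stock fairseq behavior rather than this repo's exact scripts: --max-tokens only caps the tokens per batch, while the max_positions error points at the separate position limits.)

    # Illustrative placeholders: in stock fairseq, the model's max_positions is
    # raised via --max-source-positions/--max-target-positions, independently of
    # the per-batch token cap --max-tokens.
    fairseq-train <data-bin> --max-tokens 4096 \
        --max-source-positions 2048 --max-target-positions 2048 ...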

Some questions about the baseline

Hello, if I want to reproduce the results of the strong baseline in your paper, Transformer on sent (baseline), how should I set arch, task, and the other fairseq-train parameters?
Also, what is the difference between this baseline and SentNMT?
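
(As a reference point only, a minimal sketch of a standard sentence-level Transformer baseline in stock fairseq; these flags are an assumption, not the paper's confirmed configuration.)

    # Hypothetical sentence-level baseline; <data-bin> and all hyperparameters
    # are placeholders rather than the paper's verified settings.
    fairseq-train <data-bin> --task translation --arch transformer \
        --optimizer adam --adam-betas '(0.9, 0.98)' --lr 5e-4 \
        --lr-scheduler inverse_sqrt --warmup-updates 4000 \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.1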

Loading checkpoint_last.pt raises an error

  File "train.py", line 14, in <module>
    cli_main()
  File "/data1/mzlv/g-transformer/fairseq_cli/train.py", line 347, in cli_main
    cli_main_helper(args)
  File "/data1/mzlv/g-transformer/fairseq_cli/train.py", line 374, in cli_main_helper
    fn=distributed_main, args=(args,), nprocs=args.distributed_world_size
  File "/home/mzlv/anaconda3/envs/gtrans_test_cu10/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/home/mzlv/anaconda3/envs/gtrans_test_cu10/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/mzlv/anaconda3/envs/gtrans_test_cu10/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/data1/mzlv/g-transformer/fairseq_cli/train.py", line 336, in distributed_main
    main(args, init_distributed=True)
  File "/data1/mzlv/g-transformer/fairseq_cli/train.py", line 111, in main
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
  File "/data1/mzlv/g-transformer/fairseq/checkpoint_utils.py", line 138, in load_checkpoint
    strict=not getattr(args, "load_partial", False)
  File "/data1/mzlv/g-transformer/fairseq/trainer.py", line 334, in load_checkpoint
    self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
  File "/data1/mzlv/g-transformer/fairseq/optim/fairseq_optimizer.py", line 72, in load_state_dict
    self.optimizer.load_state_dict(state_dict)
  File "/home/mzlv/anaconda3/envs/gtrans_test_cu10/lib/python3.6/site-packages/torch/optim/optimizer.py", line 111, in load_state_dict
    raise ValueError("loaded state dict has a different number of "
ValueError: loaded state dict has a different number of parameter groups

Loading the sentence-level model for document-level fine-tuning works fine, but if fine-tuning is interrupted, loading checkpoint_last.pt raises the error above. When g-transformer fine-tuning is interrupted, is retraining from scratch the only option, or are there extra parameters that need to be specified?
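
(One generic fairseq-level workaround, offered as an assumption rather than a confirmed fix: reset the optimizer state when resuming, at the cost of losing the optimizer history.)

    # Assumption, not a confirmed fix: resume from checkpoint_last.pt while
    # discarding the stale optimizer/scheduler state behind the parameter-group
    # mismatch. <data-bin> and "..." are placeholders.
    fairseq-train <data-bin> ... \
        --restore-file checkpoints/checkpoint_last.pt \
        --reset-optimizer --reset-lr-scheduler --reset-meters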

I can't use the scripts for the doc Transformer baseline

Hi author, I can't use the scripts for the doc Transformer baseline. After cloning the new fairseq and copying the baseline scripts to the fairseq root folder, I found that fairseq/models/transformer.py does not exist in the newly downloaded fairseq folder, so I appended the model settings to fairseq/models/transformer/transformer_legacy.py instead. Then I get an error. Could you please explain what I am missing or what is wrong? Thanks.

G-Transformer+BERT setting

Hi author, can you share more details of the G-Transformer+BERT experimental settings? E.g., which version of the pretrained BERT model was used (base or large, cased or uncased), and how was the BERT output decoded during inference to compute the s-BLEU scores? Thanks.

Evaluation

After training the sent and doc models, I hit this error.
Traceback (most recent call last):
  File "/home/dell/anaconda3/envs/g-transformer/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/dell/anaconda3/envs/g-transformer/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/dell/桌面/g-transformer/fairseq_cli/generate.py", line 356, in <module>
    cli_main()
  File "/home/dell/桌面/g-transformer/fairseq_cli/generate.py", line 352, in cli_main
    main(args)
  File "/home/dell/桌面/g-transformer/fairseq_cli/generate.py", line 40, in main
    return _main(args, sys.stdout)
  File "/home/dell/桌面/g-transformer/fairseq_cli/generate.py", line 169, in _main
    hypos = task.inference_step(generator, models, sample, prefix_tokens)
  File "/home/dell/桌面/g-transformer/fairseq/tasks/fairseq_task.py", line 358, in inference_step
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  File "/home/dell/anaconda3/envs/g-transformer/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/dell/桌面/g-transformer/fairseq/sequence_generator.py", line 176, in generate
    return self._generate(sample, **kwargs)
  File "/home/dell/桌面/g-transformer/fairseq/sequence_generator.py", line 331, in _generate
    _fill_frame(tokens, src_tokens)
  File "/home/dell/桌面/g-transformer/fairseq/sequence_generator.py", line 283, in _fill_frame
    docfrm = _align_frame(src[i])
  File "/home/dell/桌面/g-transformer/fairseq/sequence_generator.py", line 267, in _align_frame
    assert len(frame) % 3 == 0
AssertionError
