EA-VQ-VAE

Introduction

This repo provides the code for the ACL 2020 paper "Evidence-Aware Inferential Text Generation with Vector Quantised Variational AutoEncoder".

Requirements

  • pip install torch==1.4.0

  • pip install gdown transformers==2.8.0 filelock nltk

Download Dataset

1. Download Evidence

cd data
gdown https://drive.google.com/uc?id=1l8o0Itcr-MqKAdMxELSWGd1TnEF8eyXu
cd ..

Alternatively, you can download the retrieved evidence from the website and place it in the data folder.

2. Download and Preprocess the ATOMIC Dataset

cd data
bash get_atomic_data.sh
python preprocess-atomic.py
cd ..

3. Download and Preprocess the Event2Mind Dataset

cd data
bash get_event2mind_data.sh
python preprocess-event2mind.py
cd ..

Train Vector Quantised-Variational AutoEncoder (VQ-VAE)

We first train the VQ-VAE, which models the posterior distribution p(z|x,y).

cd vq-vae
task=event2mind #event2mind or atomic
train_steps=20000 #20000 for event2mind and 50000 for atomic
mkdir -p log model/$task
CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \
--model_name_or_path gpt2 \
--data_dir ../data/$task \
--output_dir model/$task \
--do_train \
--z_size 400 \
--max_event_length 64 \
--max_target_length 32 \
--train_batch_size 64 \
--eval_batch_size 128 \
--eval_steps 1000 \
--learning_rate 5e-5 \
--train_steps $train_steps \
--gradient_accumulation_steps 1 2>&1 | tee log/log-$task-train.txt
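For intuition, the discretisation at the core of a VQ-VAE maps each continuous encoder output to its nearest codebook vector and passes gradients straight through. Below is a minimal PyTorch sketch of that step; the names codebook and quantize are illustrative assumptions, not the repo's actual code.

import torch

z_size, hidden = 400, 768                      # 400 matches the --z_size flag above
codebook = torch.nn.Embedding(z_size, hidden)  # learnable discrete codebook

def quantize(z_e):
    # z_e: [batch, hidden] continuous encoder outputs
    dist = torch.cdist(z_e, codebook.weight)   # L2 distance to every code: [batch, z_size]
    idx = dist.argmin(dim=-1)                  # index of the nearest code
    z_q = codebook(idx)                        # quantised latent vectors
    # straight-through estimator: the forward pass uses z_q, while the
    # backward pass copies gradients to z_e as if quantisation were identity
    return z_e + (z_q - z_e).detach(), idx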

We then calculate the true prior distribution over the train and dev sets.

CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \
--model_name_or_path gpt2 \
--load_model_path model/$task/pytorch_model.bin \
--data_dir ../data/$task \
--output_dir model/$task \
--do_label \
--z_size 400 \
--max_event_length 64 \
--max_target_length 32 \
--eval_batch_size 128 2>&1 | tee log/log-$task-test.txt
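Conceptually, this --do_label pass uses the trained posterior to attach a distribution over the discrete codes to every example. A hypothetical sketch, with assumed names, of what such a labelling computes (not the repo's exact code):

import torch

def label_distribution(z_e, codebook_weight):
    # z_e: [batch, hidden] posterior encoder outputs
    # codebook_weight: [z_size, hidden] learned codebook
    dist = torch.cdist(z_e, codebook_weight)   # [batch, z_size] distances
    return torch.softmax(-dist, dim=-1)        # closer codes get higher probability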

Train Prior Distribution Estimator

Next, we train the prior distribution estimator p(z|x).

cd ../estimator
task=event2mind #event2mind or atomic
train_steps=10000 #10000 for event2mind and 20000 for atomic
mkdir -p log model/$task
CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \
--model_name_or_path roberta-large \
--prior_distribution_dir ../vq-vae/model/$task \
--data_dir ../data/$task \
--output_dir model/$task \
--do_train \
--z_size 400 \
--max_event_length 64 \
--train_batch_size 32 \
--eval_batch_size 64 \
--eval_steps 1000 \
--learning_rate 1e-5 \
--train_steps $train_steps \
--gradient_accumulation_steps 1 2>&1 | tee log/log-$task-train.txt
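Architecturally, an estimator of p(z|x) can be as simple as a RoBERTa encoder with a classification head over the z_size codes, trained to match the true prior produced in the previous step. A minimal sketch under that assumption (not the repo's exact model):

import torch
import torch.nn.functional as F
from transformers import RobertaModel

class PriorEstimator(torch.nn.Module):
    def __init__(self, z_size=400):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained("roberta-large")
        self.head = torch.nn.Linear(self.encoder.config.hidden_size, z_size)

    def forward(self, input_ids, attention_mask, true_prior=None):
        hidden = self.encoder(input_ids, attention_mask=attention_mask)[0]
        log_p = torch.log_softmax(self.head(hidden[:, 0]), dim=-1)  # [CLS] token
        if true_prior is not None:
            # fit the distribution labelled by the VQ-VAE in the previous step
            return F.kl_div(log_p, true_prior, reduction="batchmean")
        return log_p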

We then calculate the approximate prior distribution p(z|x) over the train, dev, and test sets.

CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \
--model_name_or_path roberta-large \
--prior_distribution_dir ../vq-vae/model/$task \
--load_model_path model/$task/pytorch_model.bin \
--data_dir ../data/$task \
--output_dir model/$task \
--do_label \
--z_size 400 \
--max_event_length 64 \
--eval_batch_size 128 2>&1 | tee log/log-$task-test.txt

Train Evidence-Aware Decoder

Finally, we jointly learn the context distribution p(c|z) and the generator p(y|x,c).

cd ../generator
task=event2mind #event2mind or atomic
train_steps=20000 #20000 for event2mind and 50000 for atomic
num_evidence=20
mkdir -p log model/$task
CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \
--model_name_or_path gpt2 \
--data_dir ../data/$task \
--codebook_path ../vq-vae/model/$task/codebook.bin \
--posterior_dir ../vq-vae/model/$task \
--prior_dir ../estimator/model/$task \
--do_train \
--z_size 400 \
--output_dir model/$task \
--max_evidence_length 64 \
--max_event_length 64 \
--max_target_length 32 \
--num_evidence $num_evidence \
--eval_steps 1000 \
--train_batch_size 64 \
--eval_batch_size 128 \
--learning_rate 5e-5 \
--train_steps $train_steps \
--gradient_accumulation_steps 1 2>&1 | tee log/log-$task-train.txt
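One way to picture p(c|z): each latent code scores the retrieved evidence, and the generator conditions on evidence weighted by those scores. A toy sketch with assumed names and shapes (not the repo's implementation):

import torch

def evidence_distribution(z_q, evidence_emb):
    # z_q: [hidden] selected latent code; evidence_emb: [num_evidence, hidden]
    scores = evidence_emb @ z_q            # dot-product relevance of each evidence
    return torch.softmax(scores, dim=-1)   # p(c|z): distribution over evidence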

We then obtain the top-K latent variables used to select evidence.

CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \
--model_name_or_path gpt2 \
--load_model_path model/$task/pytorch_model.bin \
--data_dir ../data/$task \
--codebook_path ../vq-vae/model/$task/codebook.bin \
--posterior_dir ../vq-vae/model/$task \
--prior_dir ../estimator/model/$task \
--do_topk \
--z_size 400 \
--output_dir model/$task \
--max_evidence_length 64 \
--max_event_length 64 \
--max_target_length 32 \
--num_evidence $num_evidence \
--eval_batch_size 128  2>&1 | tee log/log-$task-topk.txt
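Under the hood, the top-K step just keeps the K most probable codes under the estimated prior; in PyTorch terms (variable names are illustrative):

import torch

prior = torch.rand(400).softmax(dim=-1)        # stand-in for p(z|x), z_size = 400
topk_prob, topk_idx = torch.topk(prior, k=10)  # the K most probable latent codes
# each index picks a codebook vector, which in turn ranks the candidate evidence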

Finally, we use the top-K latent variables to select evidence for inference.

CUDA_VISIBLE_DEVICES=0,1,2,3 python run.py \
--model_name_or_path gpt2 \
--load_model_path model/$task/pytorch_model.bin \
--data_dir ../data/$task \
--codebook_path ../vq-vae/model/$task/codebook.bin \
--posterior_dir ../vq-vae/model/$task \
--prior_dir ../estimator/model/$task \
--do_test \
--z_size 400 \
--output_dir model/$task \
--max_evidence_length 64 \
--max_event_length 64 \
--max_target_length 32 \
--num_evidence $num_evidence \
--eval_batch_size 128  2>&1 | tee log/log-$task-test.txt

Cite

If you find our code useful, please consider citing our paper:

@inproceedings{guo-etal-2020-evidence,
    title = "Evidence-Aware Inferential Text Generation with Vector Quantised Variational {A}uto{E}ncoder",
    author = "Guo, Daya  and
      Tang, Duyu  and
      Duan, Nan  and
      Yin, Jian  and
      Jiang, Daxin  and
      Zhou, Ming",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.544",
    pages = "6118--6129",
}

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Issues

Why does the vae_loss keep growing when training the VQ-VAE?

Thanks for your code! I've got a question from running it on the event2mind dataset. The re_loss drops as expected, but the vae_loss drops only at the very beginning, from 1.2431 to 1.2353, and keeps growing afterwards. Is this normal?
If the vae_loss measures the distance between the encoder output and the corresponding embedding in the codebook, does this mean that distance grows during training?
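For reference, the standard VQ-VAE objective (van den Oord et al., 2017) adds a codebook term and a commitment term to the reconstruction loss. If vae_loss corresponds to these distance terms, it can rise while re_loss falls, since the encoder may spread its outputs to make the discrete codes more informative. A sketch of the standard formulation (not necessarily this repo's exact code):

import torch.nn.functional as F

def vq_loss(z_e, z_q, beta=0.25):
    # z_e: encoder outputs; z_q: their nearest codebook vectors
    codebook_loss = F.mse_loss(z_q, z_e.detach())    # pulls codes toward the encoder
    commitment_loss = F.mse_loss(z_e, z_q.detach())  # keeps the encoder near its codes
    return codebook_loss + beta * commitment_loss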

Bug in beam search in generator

In the file generator/beam.py, in the method advance, the entries of .prevKs and .nextYs appear to be float tensors rather than integer indices. I added some debug code and saw the following.

Debug code:

for i in range(self.nextYs[-1].size(0)):
    try:
        self.scores[i, :] = beamLk[self.prevKs[-1][i], :, self.nextYs[-1][i]]
    except:
        print("problems with beam search", self.prevKs[-1][i], self.nextYs[-1][i])
    if self.nextYs[-1][i] == self._eos:
        s = torch.exp(self.scores[i]).sum()
        self.finished.append((s, len(self.nextYs) - 1, i))

Output:

/usr/local/lib/python3.6/dist-packages/transformers/modeling_gpt2.py:536: FutureWarning: The past argument is deprecated and will be removed in a future version, use past_key_values instead.
  FutureWarning,
problems with beam search tensor(0.0016, device='cuda:0') tensor(0., device='cuda:0')
problems with beam search tensor(0.0943, device='cuda:0') tensor(0.0005, device='cuda:0')
problems with beam search tensor(0.0129, device='cuda:0') tensor(0., device='cuda:0')
problems with beam search tensor(0.0282, device='cuda:0') tensor(0., device='cuda:0')
problems with beam search tensor(0.0014, device='cuda:0') tensor(0., device='cuda:0')
problems with beam search tensor(0.0015, device='cuda:0') tensor(0., device='cuda:0')
problems with beam search tensor(0.0246, device='cuda:0') tensor(0., device='cuda:0')
problems with beam search tensor(0.0431, device='cuda:0') tensor(0., device='cuda:0')
problems with beam search tensor(0.0017, device='cuda:0') tensor(7.6294e-06, device='cuda:0')
problems with beam search tensor(0.1270, device='cuda:0') tensor(0., device='cuda:0')
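A plausible root cause, judging from the printed values (an educated guess, not confirmed by the maintainers): on PyTorch versions newer than the pinned torch==1.4.0, / on integer tensors performs true division and returns floats, so OpenNMT-style beam bookkeeping such as prevK = bestScoresId / numWords yields fractions like 0.0016 instead of integer beam indices. Floor division restores integer indices; best_scores_id and num_words below are assumed names for the flat argmax ids and the vocabulary size:

# hypothetical fix inside Beam.advance, assuming OpenNMT-style bookkeeping
prev_k = best_scores_id // num_words           # beam index (was: / num_words)
next_y = best_scores_id - prev_k * num_words   # token id within that beam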
