DialogVED

Code and released pre-trained model for our ACL 2022 paper: DialogVED: A Pre-trained Latent Variable Encoder-Decoder Model for Dialog Response Generation.

News

  • Fixed bugs in DailyDialog and updated the training and evaluation scripts. (2022.06.19)
  • Optimized the code structure and removed redundant code. (2022.05.29)
  • Pretrained checkpoints of DialogVED have been released! (2022.05.17)

TODO

  • An fp16 version of DialogVED, about 700 MB in size, will be released.
  • Pre-training scripts are scheduled to be released.

Requirements

  • python==3.7
  • torch==1.3.0
  • fairseq==0.9.0
  • tensorboardX==1.7
  • pytorch_transformers
  • sklearn
  • nltk==3.5
sudo apt install default-jdk
curl https://install.meteor.com/ | sh

pip install -r requirements.txt

Pre-trained Models

We have released the following checkpoints for the pre-trained models described in the DialogVED paper. Download a pre-trained checkpoint and set the load-from-pretrained-model parameter in the fine-tuning command.

Note: DialogVED-VAE-Standard has a latent size of 32, while DialogVED-VAE-Large has a latent size of 64. DialogVED-Seq2Seq has no latent variable; it is a pure seq2seq model trained with the same settings as DialogVED. It may perform better in scenarios where diversity of responses is less important.

Fine-tuning on your own dialogue datasets!

Data preparation

Prepare your train.src, train.tgt, dev.src, dev.tgt, test.src and test.tgt files as follows: the context and the response of each dialogue sample occupy one line in the .src and .tgt file, respectively. Use [SEP] to separate different turns, or to separate the session from knowledge, when feeding input text into the encoder; the response is predicted by the decoder.

from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')    

sep = ' [SEP] '
context = [
    'So how did I do on my driving test ?', 
    'Do you want the truth ?', 
    'Of course , I do .' 
]
response = 'Well , you really did not do all that well .'

tokenized_context = sep.join([' '.join(tokenizer.tokenize(sent)) for sent in context])
tokenized_response = ' '.join(tokenizer.tokenize(response))

# each sample occupies one line: the context goes to *.src, the response to *.tgt
src_file = 'train.src'
tgt_file = 'train.tgt'
with open(src_file, 'w', encoding='utf-8') as f:
    f.write(tokenized_context + '\n')
with open(tgt_file, 'w', encoding='utf-8') as f:
    f.write(tokenized_response + '\n')
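For reference, each line of train.src is simply the turns joined by ' [SEP] '. A minimal sketch of the resulting line format, using a whitespace split as a stand-in for BertTokenizer so it runs without pytorch_transformers (the real tokenizer additionally emits '##' subword pieces):

```python
# Stand-in for BertTokenizer.tokenize: plain whitespace split (assumption,
# only to illustrate the line format; the real script uses WordPiece).
def tokenize(sent):
    return sent.split()

sep = ' [SEP] '
context = [
    'So how did I do on my driving test ?',
    'Do you want the truth ?',
    'Of course , I do .',
]

# One dialogue sample -> one line in train.src: turns joined by [SEP].
src_line = sep.join(' '.join(tokenize(sent)) for sent in context)
print(src_line)
# -> So how did I do on my driving test ? [SEP] Do you want the truth ? [SEP] Of course , I do .
```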

Binarization

PROJECT_PATH=/remote-home/wchen/project/DialogVED

USER_DIR=${PROJECT_PATH}/src
VOCAB_PATH=${PROJECT_PATH}/vocab.txt
NUM_WORKERS=20

DATA_DIR=YourDatasetDir
PROCESSED_DIR=${DATA_DIR}/processed    # put train.src, train.tgt, dev.src, dev.tgt, test.src, test.tgt here
BINARY_DIR=${DATA_DIR}/binary          # binarized dir
TASK=translation_prophetnet

fairseq-preprocess \
  --fp16 \
  --user-dir ${USER_DIR} \
  --task ${TASK} \
  --source-lang src \
  --target-lang tgt \
  --trainpref ${PROCESSED_DIR}/train \
  --validpref ${PROCESSED_DIR}/valid \
  --testpref ${PROCESSED_DIR}/test \
  --destdir ${BINARY_DIR} \
  --srcdict ${VOCAB_PATH} \
  --tgtdict ${VOCAB_PATH} \
  --workers ${NUM_WORKERS}

Train

Note: If your device does not support float16, remove --fp16. If your GPU memory is small and cannot hold the default batch size, reduce the batch size and lower the learning rate accordingly, otherwise training may not converge normally.
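One common heuristic for adjusting the learning rate to a smaller batch is linear scaling; this rule and the helper below are assumptions for illustration, not something the repo prescribes:

```python
# Hypothetical helper: linear learning-rate scaling with effective batch size.
# The base values match the command below (--lr 0.0003, --max-sentences 16,
# --update-freq 4), but the scaling rule itself is only a rule of thumb.
def scaled_lr(base_lr, base_batch, new_batch):
    return base_lr * new_batch / base_batch

base_batch = 16 * 4      # max-sentences x update-freq = effective batch of 64
small_batch = 4 * 4      # e.g. --max-sentences 4 on a small GPU
print(scaled_lr(0.0003, base_batch, small_batch))  # -> 7.5e-05
```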

PRETRAINED_MODEL_PATH='/remote-home/wchen/models/dialogved_large.pt'

PROJECT_PATH='/remote-home/wchen/project/DialogVED'
ARCH=ngram_transformer_prophet_vae_large
NUM_WORKERS=10
CRITERION=ved_loss
TASK=translation_prophetnet
USER_DIR=${PROJECT_PATH}/src
DATA_DIR=YourDatasetDir
SAVE_DIR=${DATA_DIR}/checkpoints
TB_LOGDIR=${DATA_DIR}/tensorboard


fairseq-train \
  ${DATA_DIR}/binary \
  --fp16 \
  --user-dir ${USER_DIR} --task ${TASK} --arch ${ARCH} \
  --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 1.0 \
  --lr 0.0003 \
  --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 2000 \
  --criterion $CRITERION --label-smoothing 0.1 \
  --update-freq 4 --max-tokens 4500 --max-sentences 16 \
  --num-workers ${NUM_WORKERS}  \
  --dropout 0.1 --attention-dropout 0.1 --activation-dropout 0.0 --weight-decay 0.01 \
  --encoder-layer-drop 0.0 \
  --save-dir ${SAVE_DIR} \
  --max-epoch 10 \
  --keep-last-epochs 10 \
  --max-source-positions 512 \
  --max-target-positions 128 \
  --kl-loss-weight 1.0 \
  --target-kl 5.0 \
  --cls-bow-loss-weight 0.0 \
  --latent-bow-loss-weight 1.0 \
  --masked-lm-loss-weight 0.0 \
  --tensorboard-logdir ${TB_LOGDIR} \
  --dataset-impl mmap \
  --empty-cache-freq 64 \
  --seed 1 \
  --skip-invalid-size-inputs-valid-test \
  --distributed-no-spawn \
  --ddp-backend no_c10d \
  --load-from-pretrained-model "${PRETRAINED_MODEL_PATH}"

Inference

Run fairseq-generate to generate responses for the processed test files.

BEAM=5
LENPEN=1.0
DATA_DIR=YourDatasetDir
CHECK_POINT=${SAVE_DIR}/checkpoint8.pt
OUTPUT_FILE=${DATA_DIR}/output.txt
PRED_FILE=${DATA_DIR}/pred.txt  # the final prediction file
TASK=translation_prophetnet

fairseq-generate "${DATA_DIR}"/binary \
  --path "${CHECK_POINT}" \
  --user-dir ${USER_DIR} \
  --task ${TASK} \
  --batch-size 64 \
  --gen-subset test \
  --beam ${BEAM} \
  --num-workers 4 \
  --no-repeat-ngram-size 3 \
  --lenpen ${LENPEN} \
  >"${OUTPUT_FILE}" 2>&1
grep ^H "${OUTPUT_FILE}" | cut -c 3- | sort -n | cut -f3- | sed "s/ ##//g" > "${PRED_FILE}"
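The grep/cut/sed pipeline above can also be sketched in Python, assuming fairseq-generate's tab-separated hypothesis lines (H-<id>, score, text):

```python
def extract_predictions(output_lines):
    """Collect fairseq hypothesis lines (tab-separated: H-<id>, score, text),
    restore the original sample order, and undo WordPiece splitting by
    deleting ' ##'."""
    hyps = []
    for line in output_lines:
        if line.startswith('H-'):
            tag, _score, text = line.rstrip('\n').split('\t', 2)
            hyps.append((int(tag[2:]), text))
    return [text.replace(' ##', '') for _, text in sorted(hyps)]

# Toy lines in fairseq-generate's output format (hypothetical content):
sample = [
    'S-1\tdo you want the truth ?',
    'H-1\t-0.5\twell , you real ##ly did not do well .',
    'H-0\t-0.3\thello there !',
]
print(extract_predictions(sample))
# -> ['hello there !', 'well , you really did not do well .']
```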

Fine-tuning on DailyDialog, PersonaChat and DSTC7AVSD

Data preparation

We fine-tune DialogVED on three datasets: DailyDialog, PersonaChat and DSTC7AVSD. You can download them by following the instructions in PLATO, or by running our script as follows.

bash preprocess/get_data.sh

Preprocess

bash preprocess/process.sh

Binarization

bash preprocess/binarize.sh

Training

The script train.sh takes three parameters, namely p, t and d.

  • p: pretrained model path
  • t: pretrained model type (dialogved_standard, dialogved_large or dialogved_seq2seq)
  • d: fine-tuned dataset (dailydialog, personachat or dstc7avsd)

Note: According to feedback from some developers, if your GPU memory is small and cannot hold the default batch size, reduce the batch size and lower the learning rate accordingly, otherwise training may not converge normally.

bash train.sh -p /remote-home/models/dialogved_standard.pt -t dialogved_standard -d dailydialog

Inference

The script infer.sh takes two parameters, namely d and s.

  • d: fine-tuned dataset (dailydialog, personachat or dstc7avsd)
  • s: decoding strategy (greedy, beam or sampling)
bash infer.sh -d dailydialog -s beam
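The three strategies differ in how the next token is picked from the model's distribution at each step. A toy illustration of that difference (not the repo's implementation):

```python
import random

def pick_next(probs, strategy, beam_size=2):
    """Toy next-token selection over a {token: probability} dict."""
    if strategy == 'greedy':    # always take the single most likely token
        return max(probs, key=probs.get)
    if strategy == 'beam':      # keep the top-k candidates instead of one
        return sorted(probs, key=probs.get, reverse=True)[:beam_size]
    if strategy == 'sampling':  # draw a token proportionally to its probability
        return random.choices(list(probs), weights=list(probs.values()))[0]
    raise ValueError(strategy)

probs = {'well': 0.5, 'yes': 0.3, 'no': 0.2}
print(pick_next(probs, 'greedy'))    # -> well
print(pick_next(probs, 'beam'))      # -> ['well', 'yes']
print(pick_next(probs, 'sampling'))  # random: any of the three tokens
```

Beam search expands each kept candidate at the next step and re-ranks, which is why it tends to find higher-scoring but less diverse responses than sampling.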

Evaluation

The script eval.sh takes one parameter, namely d.

  • d: fine-tuned dataset (dailydialog, personachat or dstc7avsd)
bash eval.sh -d dailydialog

Reddit dataset for pre-training

The original Reddit data for pre-training has been shared via Baidu Netdisk.

Link: https://pan.baidu.com/s/1--K9DiPtsSStV7yKQPyc7A Extraction Code: 8grj

How to Cite

If you extend or use this work, please cite the paper where it was introduced:

@inproceedings{chen-etal-2022-dialogved,
    title = "{DialogVED}: A Pre-trained Latent Variable Encoder-Decoder Model for Dialog Response Generation",
    author = "Chen, Wei and Gong, Yeyun and Wang, Song and Yao, Bolun and Qi, Weizhen and Wei, Zhongyu and Hu, Xiaowu and Zhou, Bartuer and Mao, Yi and Chen, Weizhu and Cheng, Biao and Duan, Nan",
    booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = may,
    year = "2022",
    address = "Dublin, Ireland",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.acl-long.333",
    doi = "10.18653/v1/2022.acl-long.333",
    pages = "4852--4864",
    abstract = "Dialog response generation in open domain is an important research topic where the main challenge is to generate relevant and diverse responses. In this paper, we propose a new dialog pre-training framework called DialogVED, which introduces continuous latent variables into the enhanced encoder-decoder pre-training framework to increase the relevance and diversity of responses. With the help of a large dialog corpus (Reddit), we pre-train the model using the following 4 tasks, used in training language models (LMs) and Variational Autoencoders (VAEs) literature: 1) masked language model; 2) response generation; 3) bag-of-words prediction; and 4) KL divergence reduction. We also add additional parameters to model the turn structure in dialogs to improve the performance of the pre-trained model. We conduct experiments on PersonaChat, DailyDialog, and DSTC7-AVSD benchmarks for response generation. Experimental results show that our model achieves the new state-of-the-art results on all these datasets.",
}

dialogved's People

Contributors

lemuria-wchen


dialogved's Issues

Could you check the link for DialogVED-VAE-Large

Hi, thank you for sharing such a great work.
The current version of DialogVED-VAE-Large does not seem to work: there is a dimension mismatch for the latent variable.
Could you double-check that you have uploaded the correct model?

Thank you,
Jin

pre-trained checkpoint download link

I have had some difficulty downloading the checkpoint of the pre-trained model. So, I would like to ask whether a Baidu Netdisk download link for [dialogved_standard.pt] could be provided.

Question about K-L term objective

Thanks for your contribution; this novel VED-style pre-trained dialog model is very helpful!
After reading your research, I have a question about the KL regularization term.
Most variational dialog generation models use the CVAE objective: they have prior and posterior networks and use the KL term $KL(q(z|c,r) \| p(z|c))$.
But in your work, the posterior $q(z|c,r)$ is replaced by $q(z)$.
May I ask why you use this replacement?
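For context on the quantity being discussed, the KL term between two diagonal Gaussians has a closed form; a minimal sketch of that formula (a generic VAE identity, not code from this repo):

```python
import math

def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """Closed-form KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal
    Gaussians, summed over dimensions."""
    return sum(
        0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
        for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p)
    )

# The KL of a distribution against itself is zero.
print(gaussian_kl([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0]))  # -> 0.0
```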

Distributed training on Windows

Hi, I am running distributed training on a single machine with multiple GPUs, but I get this error: RuntimeError: no rendezvous handler for tcp://. What should I change to fix this?

Pretraining dataset

I have had some problems obtaining the Reddit pre-training dataset. Could you release it so that I can re-pretrain DialogVED? Thanks a lot.
