mintl's Introduction

MinTL: Minimalist Transfer Learning for Task-Oriented Dialogue Systems

License: MIT

This is the implementation of the EMNLP 2020 paper:

MinTL: Minimalist Transfer Learning for Task-Oriented Dialogue Systems. Zhaojiang Lin, Andrea Madotto, Genta Indra Winata, Pascale Fung [PDF]

Citation:

If you use any source code or datasets included in this toolkit in your work, please cite the following paper. The BibTeX entry is listed below:

@article{lin2020mintl,
    title={MinTL: Minimalist Transfer Learning for Task-Oriented Dialogue Systems},
    author={Zhaojiang Lin and Andrea Madotto and Genta Indra Winata and Pascale Fung},
    journal={arXiv preprint arXiv:2009.12005},
    year={2020}
}

Abstract:

In this paper, we propose Minimalist Transfer Learning (MinTL) to simplify the system design process of task-oriented dialogue systems and alleviate the over-dependency on annotated data. MinTL is a simple yet effective transfer learning framework, which allows us to plug-and-play pre-trained seq2seq models, and jointly learn dialogue state tracking and dialogue response generation. Unlike previous approaches, which use a copy mechanism to "carryover" the old dialogue states to the new one, we introduce Levenshtein belief spans (Lev), that allows efficient dialogue state tracking with a minimal generation length. We instantiate our learning framework with two pretrained backbones: T5 (Raffel et al., 2019) and BART (Lewis et al., 2019), and evaluate them on MultiWOZ. Extensive experiments demonstrate that: 1) our systems establish new state-of-the-art results on end-to-end response generation, 2) MinTL-based systems are more robust than baseline methods in the low resource setting, and they achieve competitive results with only 20% training data, and 3) Lev greatly improves the inference efficiency.
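As a rough illustration of the idea behind Lev, the sketch below treats a dialogue state as a nested dictionary and applies a small per-turn update to it instead of regenerating the whole state. The dictionary layout, the `NULL` deletion marker, and the `apply_lev` name are assumptions made for this example; the repository's actual span format may differ.

```python
from copy import deepcopy

NULL = "NULL"  # assumed marker meaning "delete this slot"


def apply_lev(prev_state, lev):
    """Return a new state: prev_state with the per-turn update applied."""
    state = deepcopy(prev_state)
    for domain, slots in lev.items():
        domain_state = state.setdefault(domain, {})
        for slot, value in slots.items():
            if value == NULL:
                domain_state.pop(slot, None)  # slot retracted in this turn
            else:
                domain_state[slot] = value    # slot added or overwritten
        if not domain_state:
            del state[domain]                 # drop domains left empty
    return state


prev_state = {"hotel": {"area": "north", "stars": "4"}}
lev = {"hotel": {"area": "centre", "stars": NULL}, "taxi": {"leave": "17:00"}}
new_state = apply_lev(prev_state, lev)
print(new_state)
```

Only the changed slots appear in `lev`, so the sequence the model has to generate stays short even when the accumulated state grows over many turns.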

Dependency

Check the required packages, or simply run the command:

❱❱❱ pip install -r requirements.txt

Experiments Setup

We use the preprocessing script from DAMD. Please check setup.sh for data preprocessing.

Experiments

T5 End2End

❱❱❱ python train.py --mode train --context_window 2 --pretrained_checkpoint t5-small --cfg seed=557 batch_size=32

T5 DST

❱❱❱ python DST.py --mode train --context_window 3 --cfg seed=557 batch_size=32

BART End2End

❱❱❱ python train.py --mode train --context_window 2 --pretrained_checkpoint bart-large-cnn --gradient_accumulation_steps 8 --lr 3e-5 --back_bone bart --cfg seed=557 batch_size=8

BART DST

❱❱❱ python DST.py --mode train --context_window 3 --gradient_accumulation_steps 10 --pretrained_checkpoint bart-large-cnn --back_bone bart --lr 1e-5 --cfg seed=557 batch_size=4

Check run.py for more information.
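With gradient accumulation, the optimizer steps once every `gradient_accumulation_steps` mini-batches, so the number of examples per update is the product of the two settings. The arithmetic below works this out for the four commands above. It assumes a single GPU and assumes the flag behaves as 1 when it is not passed; the actual default is not stated here.

```python
def effective_batch_size(batch_size, gradient_accumulation_steps=1):
    """Examples contributing to one optimizer update on a single device."""
    return batch_size * gradient_accumulation_steps


# Values copied from the commands above; a missing flag is assumed to mean 1.
runs = {
    "T5 End2End": effective_batch_size(32),
    "T5 DST": effective_batch_size(32),
    "BART End2End": effective_batch_size(8, 8),
    "BART DST": effective_batch_size(4, 10),
}

for name, size in runs.items():
    print(f"{name}: {size} examples per update")
```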

mintl's People

Contributors

zlinao

mintl's Issues

Evaluation and dialog state

Hi, first thank you for your work!

I am trying to understand the way you evaluate models (especially the inform and success rates on MultiWOZ) when using the MinTL framework. For generating a response, the model has to encode the previous dialogue state and the context, and predict the state update which is combined with the original state to form a new one. This new dialog state is used for querying the database etc., right?

During the evaluation (these lines?), is the model given the ground-truth belief state from the previous turn, or does it use the cumulative one that was predicted in the previous turns of the particular conversation?

I see some problems in both cases. When using the ground-truth belief state from the previous turn, the metrics might be overestimated. On the other hand, when using the fully-predicted last state, the ground-truth user response is used and it might not be consistent with the previous state. So I would actually expect the metrics to be underestimated, am I right?
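The two options can be contrasted with a toy rollout. This is only a sketch of the question being asked: the flat state dictionary, the `track` function, and the deliberately faulty `toy_predictor` are all hypothetical and say nothing about which option the repository actually uses.

```python
def track(turns, predict_update, use_gold_previous):
    """Return the predicted state after each turn of one dialogue."""
    previous = {}
    predictions = []
    for turn in turns:
        update = predict_update(previous, turn["user"])
        predicted = {**previous, **update}  # old state plus predicted update
        predictions.append(predicted)
        # Option 1: reset to the annotated state. Option 2: keep our own.
        previous = dict(turn["gold_state"]) if use_gold_previous else predicted
    return predictions


def toy_predictor(previous, user_utterance):
    # Deliberately wrong about the area, right about the price.
    if "north" in user_utterance:
        return {"area": "south"}
    if "cheap" in user_utterance:
        return {"price": "cheap"}
    return {}


turns = [
    {"user": "a hotel in the north", "gold_state": {"area": "north"}},
    {"user": "something cheap",
     "gold_state": {"area": "north", "price": "cheap"}},
]

with_gold = track(turns, toy_predictor, use_gold_previous=True)
cumulative = track(turns, toy_predictor, use_gold_previous=False)
print(with_gold[-1])
print(cumulative[-1])
```

With the ground-truth previous state, the second turn matches the annotation even though the first turn was wrong; with the cumulative state, the earlier mistake carries over into every later turn.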

Thank you in advance 🙂

Unable to train: numpy.AxisError: axis 1 is out of bounds for array of dimension 1

Hello again 🙂

I am trying to run the code of this repository, but without success. I installed all requirements, then updated and ran setup.sh.
I would like to train the t5-small, so I use the suggested command:

python train.py --mode train --context_window 2 --pretrained_checkpoint t5-small --cfg seed=557 batch_size=32

However, it fails at startup with:

utils.py:448: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
Traceback (most recent call last):
  File "train.py", line 359, in <module>
    main()
  File "train.py", line 348, in main
    m.train()
  File "train.py", line 67, in train
    inputs = self.reader.convert_batch(turn_batch, py_prev, first_turn=first_turn, dst_start_token=self.model.config.decoder_start_token_id)
  File "utils.py", line 448, in convert_batch
    inputs["response_input"] = torch.tensor( np.concatenate( ( np.array(batch['input_pointer']), response_input[:,:-1]), axis=1 ) ,dtype=torch.long)
  File "<__array_function__ internals>", line 6, in concatenate
numpy.AxisError: axis 1 is out of bounds for array of dimension 1

The first array has shape (32,), the second (32, 40).
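The warning about ragged nested sequences suggests that the rows of batch['input_pointer'] have different lengths, which would explain why the first array comes out one-dimensional. One possible workaround, sketched below, is to pad the rows to a common length before building the array. The `pad_rows` helper and the pad value 0 are assumptions for illustration, not code from this repository.

```python
def pad_rows(rows, pad_value=0):
    """Right-pad every row so that all rows share the longest row's length."""
    width = max((len(row) for row in rows), default=0)
    return [list(row) + [pad_value] * (width - len(row)) for row in rows]


ragged = [[5, 6, 7], [8], [9, 10]]  # rows of unequal length
padded = pad_rows(ragged)
print(padded)
```

The padded list is rectangular, so passing it to np.array yields a two-dimensional array that np.concatenate can join along axis=1.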

These are my installed packages:

blis==0.4.1
certifi==2020.12.5
chardet==4.0.0
click==7.1.2
cymem==2.0.5
en-core-web-sm==2.2.5
filelock==3.0.12
idna==2.10
importlib-metadata==3.10.0
joblib==1.0.1
murmurhash==1.0.5
nltk==3.4.5
numpy==1.20.2
packaging==20.9
plac==1.1.3
preshed==3.0.5
pyparsing==2.4.7
regex==2021.3.17
requests==2.25.1
sacremoses==0.0.43
sentencepiece==0.1.95
six==1.15.0
spacy==2.2.2
srsly==1.0.5
thinc==7.3.1
tokenizers==0.10.1
torch==1.4.0
tqdm==4.59.0
transformers==4.4.2
typing-extensions==3.7.4.3
urllib3==1.26.4
wasabi==0.8.2
zipp==3.4.1

Am I missing something?

AttributeError: 'str' object has no attribute 'size'

Thank you for this work.

While training the model, I got this error:

Traceback (most recent call last):
  File "/content/drive/MyDrive/MinTL/train.py", line 388, in <module>
    main()
  File "/content/drive/MyDrive/MinTL/train.py", line 377, in main
    m.train()
  File "/content/drive/MyDrive/MinTL/train.py", line 130, in train
    lm_labels=inputs["response"]
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/MyDrive/MinTL/T5.py", line 69, in forward
    head_mask=head_mask,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py", line 924, in forward
    encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
AttributeError: 'str' object has no attribute 'size'

What can I do to solve it, please?
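One possible cause, assuming a newer transformers release than the code was written for: recent versions return dict-like output objects by default, and unpacking such an object like a tuple yields its string keys rather than tensors. The stand-in class below is hypothetical and only mimics that behaviour with the standard library. The traceback does not confirm that this is what happens in T5.py.

```python
from collections import OrderedDict


class FakeModelOutput(OrderedDict):
    """Stand-in for a dict-like model output; not the transformers class."""

    def to_tuple(self):
        return tuple(self.values())


encoder_outputs = FakeModelOutput(last_hidden_state=[[0.1, 0.2]])

# Tuple-style unpacking iterates over the mapping, which yields its keys.
(first_by_unpacking,) = encoder_outputs
# Converting to a tuple first yields the stored values instead.
(first_by_tuple,) = encoder_outputs.to_tuple()

print(type(first_by_unpacking).__name__, first_by_unpacking)
print(type(first_by_tuple).__name__, first_by_tuple)
```

If this is the cause, passing return_dict=False to the encoder call or calling to_tuple() on its output are possible fixes; installing the transformers version listed in requirements.txt is another.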

AttributeError: 'SimBART' object has no attribute 'shared' (End2End Bart)

I tried to run the code with BART using MBartForConditionalGeneration.

To do so, I removed these lines from Bart.py:

class MiniBART(MBartModel):
    def __init__(self, config):
        super().__init__(config)
        self.dst_decoder = type(self.decoder)(config, self.shared)
        self.dst_decoder.load_state_dict(self.decoder.state_dict())

    def tie_decoder(self):
        self.shared.padding_idx = self.config.pad_token_id
        self.dst_decoder = type(self.decoder)(self.config, self.shared)
        self.dst_decoder.load_state_dict(self.decoder.state_dict())

and used these lines instead:

class SimBART(MBartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
    def tie_decoder(self):
        pass

but I got an error when running:
python train.py --mode train --context_window 2 --pretrained_checkpoint facebook/mbart-large-50 --gradient_accumulation_steps 8 --lr 3e-5 --back_bone bart --cfg seed=557 batch_size=8

The error is:

Traceback (most recent call last):
  File "C:\Users\E\train.py", line 363, in <module>
    main()
  File "C:\Users\E\train.py", line 351, in main
    m = Model(args)
  File "C:\Users\train.py", line 35, in __init__
    self.model =  SimBART.from_pretrained(args.model_path if test else 'facebook/mbart-large-50')
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\transformers\modeling_utils.py", line 1224, in from_pretrained
    model.tie_weights()
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\transformers\modeling_utils.py", line 522, in tie_weights
    output_embeddings = self.get_output_embeddings()
  File "C:\Users\E\BART.py", line 193, in get_output_embeddings
    return _make_linear_from_emb(self.shared)  # make it on the fly
  File "C:\Users\E\miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 947, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'SimBART' object has no attribute 'shared'

What can I do to fix this error?

MultiWoZ 2.1

Can you please provide the data files for running on MultiWoZ 2.1, or the preprocessing scripts to generate annotated_user_da_with_span_full.json.zip?

Hyper-parameters for reproducing

Hello, thanks for your amazing work ^-^
I tried to reproduce the experimental results reported in the paper, following the end-to-end setting in run.py.
The result I got is about 3 points lower than the result in the paper.
I use Python 3.6, a single V100, and transformers==2.8.0; the other Python packages are the same as in requirements.txt.
