casa-dialogue-act-classifier's Introduction

CASA-Dialogue-Act-Classifier

PyTorch implementation of the paper Dialogue Act Classification with Context-Aware Self-Attention, with a generic dataset class and a PyTorch Lightning trainer. This implementation differs from the paper in the following ways:

  • It uses contextualized embeddings (e.g., BERT, RoBERTa), kept frozen and therefore not trainable, while the paper uses a combination of GloVe and ELMo.
  • It uses a simple softmax classifier, while the paper uses a CRF classifier.

To train on the Switchboard Dialogue Act dataset:

  1. Navigate to data/ using: cd data/
  2. Unzip the dataset: unzip switchboard.zip
  3. Navigate to the main dir: cd ..
  4. Install the dependencies in a separate Python environment.
  5. [Optional] Change the project_name and run_name in the logger. If you don't want to use wandb, disable the wandb logger: comment out the logger code (lines 15-20 in main.py), don't pass it to the Lightning trainer (line 32 in main.py), and comment out the logging code in Trainer.py (lines 70 and 95). By default, Lightning logs to TensorBoard.
  6. [Optional] Change the parameters (batch_size, lr, epochs, etc.) in config.py.
  7. Run main.py: python main.py
  8. The model will be trained and the best checkpoint will be saved.

To train on any dialogue act dataset:

  1. Put your data into data/. Your dataset should have the following structure:
    • dataset_name
      • dataset_name_train.csv
      • dataset_name_valid.csv
      • dataset_name_test.csv
  2. [Optional] If you don't have separate test and validation data, copy the existing test (or validation) file and rename the copy as the validation (or test) file; the validation and test data will then be identical.
  3. Update the num_classes parameter in config.py (line 18) according to your dataset.
  4. Follow the Switchboard steps from step 5 onward.
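The expected layout can be checked before training with a short script. The sketch below is a hypothetical helper, not part of the repo; it assumes the `<dataset_name>_<split>.csv` naming shown above and can optionally perform the copy described in step 2.

```python
from pathlib import Path
import shutil


def check_dataset_layout(data_dir, dataset_name, copy_missing_split=False):
    """Return the train/valid/test CSV paths expected under data_dir/dataset_name.

    Hypothetical helper: assumes files named <dataset_name>_<split>.csv.
    If copy_missing_split is True and only one of valid/test exists,
    that file is copied to stand in for the missing one (step 2 above).
    """
    root = Path(data_dir) / dataset_name
    paths = {
        split: root / f"{dataset_name}_{split}.csv"
        for split in ("train", "valid", "test")
    }
    if copy_missing_split:
        valid, test = paths["valid"], paths["test"]
        if valid.exists() and not test.exists():
            shutil.copy(valid, test)
        elif test.exists() and not valid.exists():
            shutil.copy(test, valid)
    # Fail early with a clear message instead of inside the dataset class.
    missing = [str(p) for p in paths.values() if not p.exists()]
    if missing:
        raise FileNotFoundError(f"missing dataset files: {missing}")
    return paths
```

For example, `check_dataset_layout("data", "switchboard")` would confirm the unzipped Switchboard files are where the steps above expect them.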

Note: Feel free to create an issue if you find any problem. You're also welcome to open a PR if you want to add something. Here is a list of components one could add:

  • Hyperparameter Search
  • More dialogue act classification models which are not open-sourced.

References

[1]: Raheja, V., & Tetreault, J. (2019). Dialogue Act Classification with Context-Aware Self-Attention. ArXiv, abs/1904.02594.

[2]: Lin, Z., Feng, M., Santos, C.D., Yu, M., Xiang, B., Zhou, B., & Bengio, Y. (2017). A Structured Self-attentive Sentence Embedding. ArXiv, abs/1703.03130.

[3]: Switchboard Dialogue Act corpus: http://compprag.christopherpotts.net/swda.html

casa-dialogue-act-classifier's Issues

Suspicious labels

Hey @macabdul9, I realized you used the act_label_1 column in the SwDA data shared in your repo as labels for training.
That column doesn't seem particularly good as labels, as one can see from pairs obtained from the first rows in the test data:
"Okay." - Other
"I guess" - Info-request:Yes-No-Question
"What kind of experience do you, do you have, then with child care ?" - Other:Segment-(multi-utterance)
These classes don't match the SwDA classes, and I am not sure how they were obtained.

On the other hand, the column act_tag is not a good option either, as it contains 276 different classes. I think the data needs some cleaning.

Pass callback instances to the `callbacks` argument in the Trainer constructor instead.

  File "c:\Users\user1\Documents\repo\DAC\main.py", line 54, in <module>
    trainer = pl.Trainer(
  File "C:\Users\user1\AppData\Local\Programs\Python\Python39\lib\site-packages\pytorch_lightning\trainer\connectors\env_vars_connector.py", line 40, in insert_env_defaults
    return fn(self, **kwargs)
  File "C:\Users\user1\AppData\Local\Programs\Python\Python39\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 413, in __init__
    self.callback_connector.on_trainer_init(
  File "C:\Users\user1\AppData\Local\Programs\Python\Python39\lib\site-packages\pytorch_lightning\trainer\connectors\callback_connector.py", line 52, in on_trainer_init
    self._configure_checkpoint_callbacks(checkpoint_callback)
  File "C:\Users\user1\AppData\Local\Programs\Python\Python39\lib\site-packages\pytorch_lightning\trainer\connectors\callback_connector.py", line 77, in _configure_checkpoint_callbacks
    raise MisconfigurationException(error_msg)
pytorch_lightning.utilities.exceptions.MisconfigurationException: Invalid type provided for checkpoint_callback: Expected bool but received <class 'pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint'>. Pass callback instances to the callbacks argument in the Trainer constructor instead.
PS C:\Users\user1\Documents\repo\DAC>

Trained model available?

I would like to add DA labels to my own dataset, using this model. I don't have enough computing power available to train my own model. Is a pretrained model available?

Dataset

Thank you so much for sharing the work! May I ask two questions about the dataset?

  1. How are the sentences split? From what I observed, the data might be split by sentence, but I'm not sure, and this wasn't specified in the original paper.

  2. What do the tags mean? Since I'm not an expert in NLP, I don't have much domain knowledge, so I'd really appreciate it if there were some documentation about the tags that I could refer to.

Thank you!

training seems to be too slow?

Hi, I have a system with a Ryzen Threadripper 24-core processor and a Titan RTX card. But when I run the training script, it takes more than 6 seconds per iteration. Everything is set up with your default parameters. If I train it for 100 epochs, it would potentially take 600 hours :)

Epoch 0: 0%| | 13/3337 [01:29<6:20:41, 6.87s/it, loss=3.736, v_num=48]

Is this normal, or is there something we can tune to improve performance? Thanks
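The estimate follows directly from the progress bar above (3337 iterations per epoch at about 6.87 s/it). The arithmetic is sketched below; `estimate_training_hours` is a hypothetical helper, and the cause of the slowdown is not established here. One thing worth confirming is that Lightning's startup log reports `GPU available: True, used: True` rather than `False`.

```python
def estimate_training_hours(iters_per_epoch, seconds_per_iter, epochs):
    """Rough wall-clock estimate: iterations x seconds per iteration x epochs, in hours."""
    return iters_per_epoch * seconds_per_iter * epochs / 3600


per_epoch = estimate_training_hours(3337, 6.87, 1)
total = estimate_training_hours(3337, 6.87, 100)
print(f"{per_epoch:.1f} h per epoch, {total:.0f} h for 100 epochs")
# prints "6.4 h per epoch, 637 h for 100 epochs"
```

This matches the roughly 6 h 20 min remaining time shown for epoch 0.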

What accuracy can you actually achieve?

Hi! I just cloned and ran your code, but after half a day of training the accuracy on the dev set only reached 67%, far below the 82% test accuracy reported in the original paper.
Also, when I reproduced the model with my own code and data preprocessing techniques, the best I could achieve with a hierarchical GRU was around 76% on the dev set and 74% on the test set.
Is the accuracy reported in the paper truly reproducible? Has anyone else spotted the same issue?
Thanks

Low Kaggle Accuracy

When training on Kaggle, the model only runs for 5 epochs because of early stopping. Unfortunately, this means that its accuracy on the SwDA dataset is often below 70%. Is it possible to train for longer on Kaggle?
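How soon a run ends is governed by the early-stopping patience rather than by the platform. PyTorch Lightning's `EarlyStopping` callback (with its `monitor` and `patience` arguments) halts training once the monitored metric has not improved for `patience` consecutive checks. The sketch below is a plain-Python illustration of that patience counter, not the callback itself; `stop_epoch` is hypothetical and the loss values are made up.

```python
def stop_epoch(val_losses, patience, min_delta=0.0):
    """Return the 0-based epoch at which patience-based early stopping halts,
    or None if it never triggers.

    Stops once the loss has failed to improve by more than min_delta
    for `patience` consecutive epochs.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss  # improvement: reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [1.20, 1.05, 1.07, 1.06, 1.08, 0.90]  # made-up validation losses
print(stop_epoch(losses, patience=3))  # 4: stops before the later improvement
print(stop_epoch(losses, patience=5))  # None: runs long enough to see it
```

With a short patience, a plateau of a few epochs ends the run even if the loss would have improved later, so raising `patience` (and `max_epochs`) is one way to train longer. Where the repo sets these values is not shown here.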

Export Models to ONNX (Feature Request)

As mentioned in #5, the repository currently doesn't have a class for inferencing new data, and requires tweaking the model definitions to make this possible. It would greatly benefit the users if the repo had this out-of-the-box functionality.

I suggest a function be created to export a checkpoint model to ONNX, as is standard in pytorch-lightning for inferencing in production (from docs), along with any required compatibility changes to the model definitions.

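import torch

# model: the trained model, switched to eval mode first (model.eval());
# example_input: a tuple of sample tensors (input_ids, attention_mask, seq_len);
# outname: output path for the .onnx file
torch.onnx.export(model,
                  args=example_input,
                  f=outname,
                  input_names=['input_ids', 'attention_mask', 'seq_len'],
                  output_names=['label'],
                  export_params=True)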

Ideally a separate module e.g. inference.py can be written with optimized imports and a class to preprocess (tokenize) and run inferences from a collection of dialog strings.

incompatible dataset labels

Labels for the train, valid and test datasets are created independently of each other. See

classes = sorted(set(self.acts))

Even if sorted, they won't be compatible if the validation/test splits don't contain all the labels.

For the Switchboard dataset, the test data contains 5 fewer labels than the training data, so the sorted label indices shift between splits and the predictions will be off.
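One way to keep the splits consistent is to build the label-to-index mapping once, from the training labels, and reuse it for every split (another option is to build it from the union of all splits). The plain-Python sketch below is not the repo's code; `build_label_map` is hypothetical and the tags are example values.

```python
def build_label_map(train_labels, *other_splits):
    """Build one label -> index mapping from the training labels and reuse it for all splits.

    Raises if a validation/test split contains a label never seen in training,
    since the classifier would have no output unit for it.
    """
    classes = sorted(set(train_labels))
    label_map = {label: idx for idx, label in enumerate(classes)}
    for split in other_splits:
        unseen = sorted(set(split) - set(label_map))
        if unseen:
            raise ValueError(f"labels not present in training data: {unseen}")
    return label_map


train = ["sd", "b", "sv", "qy", "b"]
test = ["b", "sd"]  # a split that lacks some training labels
label_map = build_label_map(train, test)
print(label_map)                      # {'b': 0, 'qy': 1, 'sd': 2, 'sv': 3}
print([label_map[y] for y in test])   # [0, 2]
```

Building the mapping from `test` alone would give `{'b': 0, 'sd': 1}`, so "sd" would get a different index than in training, which is exactly the mismatch described above.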

SwDA scores?

Thanks for the implementation!
What scores did you achieve in the SwDA dataset? Do you reach the original paper's result?

Withdraw fake papers!

Can anyone achieve the same accuracy as reported in these papers?
I suspect no one can really do that.
These numbers are probably made by the authors.
They should be withdrawn.

Classifier outputs `hidden_size` values

If I understand correctly, the last layer in the ContextAwareDAC is the classifier, and should output as many values as classes there are, as suggested by this line also:

nn.Linear(in_features=128, out_features=num_classes)

However, the constructor of the model seems to pass a number of classes that's equal to the hidden_size parameter in config:

num_classes=self.config['hidden_size'],

I am guessing this is a typo, and it should say instead

num_classes=self.config['num_classes']

Also, there's no other use of config['num_classes'] anywhere else in the repository, which makes me suspect the typo even more.

Prediction example?

Hi @macabdul9, after training a model that reached 75% on the test set, I'd like to use it to predict speech acts for other data. Could you provide an example of how to run inference with the stored checkpoint? So far I haven't been able to process the new data (from a CSV file) accordingly. I tried using a DataLoader, but it doesn't seem to work.
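A full answer depends on the repo's model and dataset classes, which are not shown here. The last step of inference, turning class scores back into dialogue-act names, can be sketched in plain Python. The sketch assumes the model returns one row of class scores per utterance (for example, logits converted with `.tolist()`) and that the label-to-index mapping used during training is available; `decode_predictions` and the example tags are hypothetical.

```python
def decode_predictions(logits_batch, label_map):
    """Turn per-utterance class scores into label names.

    logits_batch: list of score lists, one per utterance.
    label_map: the same label -> index dict used during training.
    """
    index_to_label = {idx: label for label, idx in label_map.items()}
    predictions = []
    for scores in logits_batch:
        best_idx = max(range(len(scores)), key=scores.__getitem__)  # argmax
        predictions.append(index_to_label[best_idx])
    return predictions


label_map = {"b": 0, "qy": 1, "sd": 2, "sv": 3}
logits = [[0.1, 2.3, -1.0, 0.4], [1.5, 0.2, 0.3, 0.1]]
print(decode_predictions(logits, label_map))  # ['qy', 'b']
```

Reusing the training mapping matters here: decoding with a mapping rebuilt from the new data would silently assign the wrong names.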

Conversations / utterances

Hi there,

Very basic question, apologies if I am missing something.

Looking at your data (switchboard_train.csv), I am having a hard time seeing how the utterances are arranged in conversations to achieve the conversation-level context representation. It seems to be a list of utterances, with no hierarchical structure. Using your data generator (DADataset) with this input file yields 193K samples, which is equal to the number of utterances reported for the dataset (Table 2). Is this expected?
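Having one sample per utterance is not by itself incompatible with conversation-level context: each sample can carry the preceding utterances of its own conversation. The sketch below illustrates that grouping in plain Python. It assumes each row has a conversation id and a turn index; the keys `conv_id`, `turn`, `text` and `act` are hypothetical, and whether the repo's CSV contains such columns is not confirmed here.

```python
from collections import defaultdict


def context_windows(rows, window=3):
    """Group utterance rows by conversation and yield each utterance with its preceding context.

    rows: iterable of dicts with hypothetical keys 'conv_id', 'turn', 'text', 'act'.
    Yields (context_texts, text, act), where context_texts holds up to `window`
    earlier utterances from the same conversation, oldest first.
    """
    conversations = defaultdict(list)
    for row in rows:
        conversations[row["conv_id"]].append(row)
    for conv_id in sorted(conversations):
        turns = sorted(conversations[conv_id], key=lambda r: r["turn"])
        for i, row in enumerate(turns):
            context = [r["text"] for r in turns[max(0, i - window):i]]
            yield context, row["text"], row["act"]


rows = [
    {"conv_id": "c1", "turn": 1, "text": "Okay.", "act": "b"},
    {"conv_id": "c1", "turn": 0, "text": "Hi.", "act": "fp"},
    {"conv_id": "c2", "turn": 0, "text": "Do you have kids?", "act": "qy"},
]
for sample in context_windows(rows, window=2):
    print(sample)
```

The number of samples still equals the number of utterances (3 here); only the first utterance of each conversation has an empty context.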

Training fails when the last batch has only one sample.

Training fails when the size of the training/validation/test set leaves a remainder of 1 when divided by the batch size.
In my use case, the validation set has 18753 elements, so a batch size of 16 leaves only one element in the last batch, and the following error occurs:

Epoch 0: 100%|█████████▉| 12208/12209 [1:55:15<00:00,  1.77it/s, loss=1.960, v_num=gsq1]
Traceback (most recent call last):
  File "main.py", line 53, in <module>
    trainer.fit(model)
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 444, in fit
    results = self.accelerator_backend.train()
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 63, in train
    results = self.train_or_test()
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
    results = self.trainer.train()
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 493, in train
    self.train_loop.run_training_epoch()
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 589, in run_training_epoch
    self.trainer.run_evaluation(test_mode=False)
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 578, in run_evaluation
    output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 171, in evaluation_step
    output = self.trainer.accelerator_backend.validation_step(args)
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 87, in validation_step
    output = self.__validation_step(args)
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 95, in __validation_step
    output = self.trainer.model.validation_step(*args)
  File "/root/CASA-Dialogue-Act-Classifier/Trainer.py", line 82, in validation_step
    loss = F.cross_entropy(logits, targets)
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/torch/nn/functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/root/CASA-Dialogue-Act-Classifier/venv/lib/python3.7/site-packages/torch/nn/functional.py", line 2260, in nll_loss
    if input.size(0) != target.size(0):
IndexError: dimension specified as 0 but tensor has no dimensions
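The arithmetic in this report checks out: 18753 = 16 × 1172 + 1. The failure happens in `F.cross_entropy` inside `validation_step`, where a tensor with no dimensions reaches `nll_loss`. That suggests the batch dimension was squeezed away when the batch held a single sample (for example by a `.squeeze()` call on the targets; this is an assumption, since the dataset and model code are not shown here). Besides removing such a squeeze, two common workarounds are `DataLoader(..., drop_last=True)` for training, or a batch size that does not leave a single-sample batch. The hypothetical helpers below check the latter.

```python
def last_batch_size(num_samples, batch_size):
    """Size of the final batch a DataLoader would produce with drop_last=False."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


def batch_sizes_avoiding_singleton(num_samples, candidates):
    """Keep only the candidate batch sizes whose last batch has more than one sample."""
    return [b for b in candidates if last_batch_size(num_samples, b) > 1]


print(last_batch_size(18753, 16))  # 1
print(batch_sizes_avoiding_singleton(18753, [8, 16, 24, 32, 64, 128]))  # [24, 128]
```

For this validation set, every power of two from 8 to 64 leaves exactly one sample, because 18752 = 2^6 × 293.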

Does the dataset class still have some problems?

python main.py
wandb: WARNING W&B installed but not logged in. Run wandb login or set the WANDB_API_KEY env variable.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose 'Don't visualize my results'
wandb: Offline run mode, not syncing to the cloud.
wandb: W&B is disabled in this directory. Run wandb on to enable cloud syncing.

  | Name  | Type            | Params
0 | model | ContextAwareDAC | 130 M
Validation sanity check: 0it [00:00, ?it/s]wandb: WARNING W&B installed but not logged in. Run wandb login or set the WANDB_API_KEY env variable.
wandb: WARNING W&B installed but not logged in. Run wandb login or set the WANDB_API_KEY env variable.
wandb: WARNING W&B installed but not logged in. Run wandb login or set the WANDB_API_KEY env variable.
wandb: WARNING W&B installed but not logged in. Run wandb login or set the WANDB_API_KEY env variable.
Traceback (most recent call last):
  File "main.py", line 52, in <module>
    trainer.fit(model)
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 444, in fit
    results = self.accelerator_backend.train()
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py", line 57, in train
    results = self.train_or_test()
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
    results = self.trainer.train()
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 466, in train
    self.run_sanity_check(self.get_model())
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 658, in run_sanity_check
    _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 566, in run_evaluation
    for batch_idx, batch in enumerate(dataloader):
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
    return self._process_data(data)
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
    data.reraise()
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/my_proj_env/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/zhaonan8/github_project/CASA-Dialogue-Act-Classifier/dataset/dataset.py", line 31, in __getitem__
    target = DADataset._label_dict[label]
KeyError: 'fo_o_fw
"_by_bc'

wandb: Waiting for W&B process to finish, PID 3182
wandb: Program failed with code 1.
wandb: Find user logs for this run at: ./wandb/offline-run-20210127_191806-1pex8c61/logs/debug.log
wandb: Find internal logs for this run at: ./wandb/offline-run-20210127_191806-1pex8c61/logs/debug-internal.log
wandb: You can sync this run to the cloud by running:
wandb: wandb sync ./wandb/offline-run-20210127_191806-1pex8c61
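The traceback ends at `target = DADataset._label_dict[label]`, meaning a label in the CSV has no entry in the label dictionary. A quick pre-training check can list every such label at once instead of failing inside a DataLoader worker. The sketch below is hypothetical: `find_unknown_labels`, the column name `act`, the label dictionary, and the example tag (which contains a quote character, assumed to resemble the one in the traceback) are all illustrative. Reading with the `csv` module keeps quoted fields intact, so the compared labels are exactly what the file contains.

```python
import csv
import io


def find_unknown_labels(csv_text, label_column, label_dict):
    """Return {label: count} for labels in `label_column` missing from `label_dict`."""
    unknown = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        label = row[label_column]
        if label not in label_dict:
            unknown[label] = unknown.get(label, 0) + 1
    return unknown


data = 'text,act\nOkay.,b\nUh-huh.,"fo_o_fw_""_by_bc"\n'
label_dict = {"b": 0, "sd": 1}
print(find_unknown_labels(data, "act", label_dict))  # {'fo_o_fw_"_by_bc': 1}
```

For a real file, the same function can be called with the file's contents, e.g. `Path(...).read_text()`, and the column actually used for labels.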
