Probing Neural Dialog Systems for Conversational Understanding

This code accompanies the paper "Probing Neural Dialog Models for Conversational Understanding" by Saleh et al. (2020).

This repo is built on top of ParlAI. We add functionality for probing open-domain dialog models (RNNs and Transformers). Probing evaluates the quality of internal model representations for conversational skills.

Setup

Follow the same installation instructions as ParlAI, which requires Python 3 and PyTorch 1.1 or newer. After cloning this repo, remember to run:

python setup.py develop

You will also need to install skorch 0.6, which is required by the probing classifier.
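For example, with pip (the exact pin shown is one way to satisfy the version requirement above):

```shell
pip install skorch==0.6.0
```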

Example Usage

This section takes you through an example of how you would train and probe a dialog model.

  1. You will first need a model to probe. Let's train a small RNN on the DailyDialog dataset:

    python examples/train_model.py -t dailydialog -m seq2seq \
        --bidirectional true --numlayers 2 --hiddensize 256 --embeddingsize 128 \
        -eps 60 -veps 1 -vp 10 -bs 32 --optimizer adam --lr-scheduler invsqrt \
        -lr 0.005 --dropout 0.3 --warmup-updates 4000 -tr 300 \
        -mf trained/dailydialog/seq2seq/seq2seq --display-examples True \
        -ltim 30 --tensorboard_log True --validation-metric ppl
    

    To train on perturbed (i.e. shuffled) data, add the -sh flag to the command above. See ParlAI's documentation for more information about training dialog models.

  2. You will then generate and save the vector representations to be used as features by the probing classifier. Let's generate and save the encoder_state vectors for the TREC question classification task:

    python probing/probe_model.py -mf trained/dailydialog/seq2seq/seq2seq -t probing.tasks.trecquestion.agents --probe encoder_state 
    

    This will automatically download the required task data and save the generated representations at trained/dailydialog/seq2seq/probing/encoder_state/trecquestion.

    There are three types of hidden representations you can extract and probe:

    • encoder_state
    • word_embeddings
    • combined

    Refer to the paper for more info.
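    If you want all three representation types for a task, the extraction step can be looped. This is just a convenience sketch reusing the paths and flags from the example above:

    ```shell
    # Extract all three representation types for the TREC question task.
    for probe in encoder_state word_embeddings combined; do
        python probing/probe_model.py -mf trained/dailydialog/seq2seq/seq2seq \
            -t probing.tasks.trecquestion.agents --probe "$probe"
    done
    ```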

  3. Now run the probing classifier to evaluate the quality of the generated representations:

    python probing/eval_probing.py -m trained/dailydialog/seq2seq -t trecquestion --probing-module encoder_state --max_epochs 50 --runs 30
    

    This trains the probing classifier (an MLP) on the generated representations. The final results are saved at trained/dailydialog/seq2seq/probing/encoder_state/trecquestion/results.json.
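    The exact schema of results.json is not documented here, but since it is standard JSON you can inspect it with Python's built-in json module. A minimal sketch (the path matches the example above):

    ```python
    import json

    def load_probing_results(path):
        """Parse a results.json written by eval_probing.py.

        The file's exact schema is not documented here, so this
        simply returns the parsed JSON for inspection.
        """
        with open(path) as f:
            return json.load(f)

    # Example, using the path from the step above:
    # results = load_probing_results(
    #     "trained/dailydialog/seq2seq/probing/encoder_state/trecquestion/results.json")
    # print(results)
    ```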

Baselines

  1. To generate the GloVe word embedding baselines, run:

    python probing/glove.py -t trecquestion --dict-path trained/dailydialog/seq2seq/seq2seq.dict
    

    This will automatically download the GloVe embeddings and save the generated representations to trained/GloVe/probing/trecquestion.

  2. Now you need to run the probing classifier on these generated representations using:

    python probing/eval_probing.py -m GloVe -t trecquestion --max_epochs 50 --runs 30
    

    The final results are saved at trained/GloVe/probing/trecquestion/results.json.

Important note: It is strongly recommended that you use the same directories as in the examples when saving your checkpoints and data; otherwise, other parts of the code will break.

Probing tasks

The supported probing tasks are:

  • trecquestion
  • multiwoz
  • sgd
  • dialoguenli
  • wnli
  • snips
  • scenariosa
  • dailydialog_topic

Adding new probing tasks

New probing tasks need to be in the following format:

text: <utterance1> \n
<utterance2> \t episode_done:True \n  
text: <utterance1> \n
<utterance2> \n
<utterance3> \t episode_done:True \n 
... 

New probing tasks need to be added to probing/tasks and glove.py.
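As a sketch of the format above, here is how one might write a toy task file. The utterances and the output filename are hypothetical examples; any labels or other fields a specific task requires are omitted here:

```python
# Sketch: write episodes in the format described above.
episodes = [
    ["Hi, how are you?", "Fine, thanks. And you?"],
    ["Where is the station?", "Two blocks north.", "Thanks!"],
]

with open("toy_task.txt", "w") as f:
    for utterances in episodes:
        # "text:" starts an episode; utterances are newline-separated,
        # and the episode ends with a tab-separated episode_done flag.
        f.write("text: " + "\n".join(utterances) + "\tepisode_done:True\n")
```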

The code is best suited for tasks where the label is based on:

  • all utterances in a dialog (like DailyDialog Topic)
  • or the interaction between two utterances (like DialogueNLI)
  • or the last utterance in a dialog (like ScenarioSA)

See section 3.2 in the paper for more info.

Code Organization

Most of the code for this study exists within the probing directory.

We also augmented torch_generator_agent.py with probing functions that extract a model's internal representations.

Reference

If you use this code please cite our paper:

@article{saleh2020probing,
    author = {Saleh, Abdelrhman and Deutsch, Tovly and Casper, Stephen and Belinkov, Yonatan and Shieber, Stuart},
    title = {Probing Neural Dialog Models for Conversational Understanding},
    journal = {Second Workshop on NLP for Conversational AI},
    year = {2020}
}
