selfexplain's Introduction

SelfExplain Framework

The code for the SelfExplain framework (https://arxiv.org/abs/2103.12279)

Currently, this repo supports the SelfExplain-XLNet and SelfExplain-RoBERTa versions for the SST-2, SST-5, and SUBJ datasets. We have also tested it on CoLA, where only RoBERTa gives reasonable performance because the sentences in CoLA are too short for XLNet.

Preprocessing

Data for preprocessing is available in the data/ folder.

In a Python shell, run the following to install the parser:

>>> import benepar
>>> benepar.download('benepar_en3')

Then run the preprocessing script from the command line:

sh scripts/run_preprocessing.sh

Before preprocessing, adjust the hyperparameters at the top of the script. We have created two separate folders in the data folder: RoBERTa-SST-2 and XLNet-SST-2. We expect users to follow this practice because the concept store is unique to each Transformer-based classifier and each dataset.

Please confirm that DATA_FOLDER is the correct path and that TOKENIZER_NAME is the tokenizer you want to use (roberta-base or xlnet-base-cased). Also confirm MAX_LENGTH, because it affects the number of concepts. If MAX_LENGTH is small and the average length in the dataset is long, you may run into training errors.

Example:

export DATA_FOLDER='data/XLNet-SST-2'
export TOKENIZER_NAME='xlnet-base-cased'
export MAX_LENGTH=5

Note: if you wish to parse test.tsv, edit process_trec_dataset.py at line 57. We have provided data for SST-2 and SUBJ.

Training

For training, please edit the data path and adjust the other parameters as needed.

sh scripts/run_self_explain.sh

Example:

python model/run.py --dataset_basedir data/RoBERTa-SST-2 \
                         --lr 2e-5  --max_epochs 5 \
                         --gpus 1 \
                         --model_name roberta-base \
                         --concept_store data/RoBERTa-SST-2/concept_store.pt \
                         --topk 5 \
                         --gamma 0.1 \
                         --lamda 0.1

Note that the specified model_name must match the tokenizer used in the preprocessing stage.

Generation (Inference)

The original author states that this part is still under development. We have used it and it works well.

 python model/infer_model.py \
        --ckpt $PATH_TO_BEST_DEV_CHECKPOINT \
        --concept_map $DATA_FOLDER/concept_idx.json \
        --batch_size $BS \
        --paths_output_loc $PATH_TO_OUTPUT_PREDS \
        --dev_file $PATH_TO_DEV_FILE

Example:

 python model/infer_model.py \
      --ckpt lightning_logs/version_3/checkpoints/epoch=2-step=1499-val_acc_epoch=0.9570.ckpt \
      --concept_map data/RoBERTa-SST-2/concept_idx.json \
      --paths_output_loc result/result_roberta_7.csv \
      --dev_file data/RoBERTa-SST-2/dev_with_parse.json \
      --batch_size 16

Citation

@inproceedings{rajagopal-etal-2021-selfexplain,
    title = "{SELFEXPLAIN}: A Self-Explaining Architecture for Neural Text Classifiers",
    author = "Rajagopal, Dheeraj  and
      Balachandran, Vidhisha  and
      Hovy, Eduard H  and
      Tsvetkov, Yulia",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    month = nov,
    year = "2021",
    address = "Online and Punta Cana, Dominican Republic",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.emnlp-main.64",
    doi = "10.18653/v1/2021.emnlp-main.64",
    pages = "836--850",
}

selfexplain's People

Contributors

dheerajrajagopal, mk322, xuweiyichen

selfexplain's Issues

Inaccurate loss

According to the paper (section 2.5), the final loss is calculated as a weighted combination of three loss terms: the LIL loss, the GIL loss, and the task-based CE loss. However, in the code, the logits are calculated as the weighted sum of the LIL, GIL, and task-based logits BEFORE the final loss is computed. Because of this, the two are not equivalent.

i.e., log(softmax(a + b)) is not equivalent to log(softmax(a)) + log(softmax(b))
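
The non-equivalence can be checked numerically. The following is a minimal sketch in plain Python, not code from the repository: the logits `a` and `b`, the target class, and the weight `w` are made-up values chosen only for illustration.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def cross_entropy(logits, target):
    """Negative log-probability of the target class."""
    return -log_softmax(logits)[target]

# Hypothetical logits from two heads for one 2-class example.
a = [2.0, 0.0]
b = [0.0, 1.0]
target = 0
w = 0.5  # illustrative weight, not a value taken from the repo

# (1) Combine the logits first, then compute a single loss.
loss_of_sum = cross_entropy([x + w * y for x, y in zip(a, b)], target)

# (2) Compute one loss per head, then combine the losses.
sum_of_losses = cross_entropy(a, target) + w * cross_entropy(b, target)

print(loss_of_sum, sum_of_losses)  # the two values differ
```

With these values the two quantities differ by more than 0.5, so the gradients each head receives would differ as well.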

LIL Implementation

phrase_level_logits = self.phrase_logits(phrase_level_activations)

The implementation of LIL differs from what is in the paper, and I am a bit confused about that as well. If we go with this implementation, then the mean we are taking is not actually a division by the len(nt) matrix.

Could you provide a Colab notebook for running all your code?

Thank you very much for your code.

However, I am stuck trying to run the training and inference steps to obtain an explanation. Could you provide a notebook that contains your code? It could be a Colab notebook with a small example that demonstrates how to obtain an explanation.

I have made several attempts to install the required library, but it still fails every time.

Thank you so much.

Missing normalization based on phrase length?

According to the paper (section 2.2), constituent word representations are taken to be the average of token representations of the phrase (non-terminal) tokens.

The code actually does a batch matrix multiplication, and therefore computes the sum of the hidden token representations. This may affect both the magnitude and the direction of the phrase-level representation after the activation is applied.

Am I missing something?
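
The difference described above can be sketched with NumPy. This is an illustration under assumed shapes, not the repository's code: `hidden` and `nt_mask` are hypothetical names and values standing in for the token representations and the one-hot non-terminal masks.

```python
import numpy as np

# Hypothetical input: 1 sentence, 4 tokens, hidden size 3.
hidden = np.array([[[1.0, 0.0, 2.0],
                    [3.0, 2.0, 0.0],
                    [5.0, 4.0, 4.0],
                    [0.0, 1.0, 1.0]]])            # (batch, tokens, dim)

# One-hot phrase masks: phrase 0 spans tokens 0-2, phrase 1 spans token 3.
nt_mask = np.array([[[1.0, 1.0, 1.0, 0.0],
                     [0.0, 0.0, 0.0, 1.0]]])      # (batch, phrases, tokens)

# Batch matmul alone yields the SUM of token vectors per phrase.
phrase_sum = nt_mask @ hidden                      # phrase 0 -> [9, 6, 6]

# Dividing by the phrase length gives the average described in the paper.
lengths = nt_mask.sum(axis=-1, keepdims=True)      # tokens per phrase
phrase_mean = phrase_sum / np.clip(lengths, 1.0, None)  # phrase 0 -> [3, 2, 2]
```

The `np.clip` guard only avoids division by zero for all-zero (padding) masks. Whether the repository normalizes elsewhere is not something this sketch can confirm.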

Not runnable on Windows

Hi,

I could not run run_self_explain.sh because of "import resource" at line 8 of run.py.
The resource module is only available on Unix systems, not on Windows. Is there any way I can fix it?

Thanks!
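
One possible workaround is to guard the import so the script still starts on Windows. The sketch below assumes that run.py uses `resource` only to raise the open-file limit; the source does not show that code, so `raise_open_file_limit` and the target of 4096 are hypothetical.

```python
try:
    import resource  # POSIX-only module in the standard library
except ImportError:
    resource = None  # e.g. on Windows

def raise_open_file_limit(target=4096):
    """Try to raise the soft open-file limit; return it, or None without `resource`."""
    if resource is None:
        return None
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard != resource.RLIM_INFINITY:
        target = min(target, hard)  # the soft limit may not exceed the hard limit
    if soft == resource.RLIM_INFINITY or soft >= target:
        return soft  # already high enough, leave it unchanged
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    except (ValueError, OSError):
        return soft  # the platform refused the change; keep the old limit
    return target

limit = raise_open_file_limit()
```

On Windows the function simply returns None and training would proceed with the default limits.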

GIL implementation

Hi, I'm trying to understand your code and reproduce it. I have a question about the GIL implementation.
Your paper says that q_{k} in the concept store Q is constantly updated:
(As the model M is finetuned for a downstream task, the representations qk are constantly updated)
However, after looking at the bash and Python files, I found that you build the concept store at the very beginning using the original XLNet checkpoint and then save it as a static concept_store.pt file. During training, it seems that the .pt file is not updated.
I'm a bit puzzled here. Did I miss a detail? Or could you point out where the function for updating the embeddings in Q is?
Thanks in advance!

Noise on LIL layer due to batching

The codebase uses batching to process multiple sentences at the same time. Each sentence can be broken down into multiple phrases, represented by non-terminal one-hot vectors. Because not all sentences in a batch decompose into the same number of phrases, some sentences have empty phrases. Ideally, lil_logits should MASK out all such representations so that they do not contribute to the total loss. If not, the lil_logits_mean will be dominated by 0 - pooled_seq_rep.
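
The masking suggested above could look roughly like the following NumPy sketch. The array names, shapes, and values are hypothetical; the -5.0 entries stand in for whatever the padded phrase slots produce.

```python
import numpy as np

# Hypothetical per-phrase logits: 2 sentences, padded to 3 phrases, 2 classes.
lil_logits = np.array([[[2.0, 0.0], [4.0, 2.0], [0.0, 4.0]],
                       [[1.0, 3.0], [-5.0, -5.0], [-5.0, -5.0]]])

# 1 marks a real phrase, 0 marks padding; sentence 1 has a single real phrase.
phrase_mask = np.array([[1.0, 1.0, 1.0],
                        [1.0, 0.0, 0.0]])

# Naive mean averages over padded slots too.
naive_mean = lil_logits.mean(axis=1)

# Masked mean: zero out padded slots, then divide by the real phrase count.
masked_sum = (lil_logits * phrase_mask[:, :, None]).sum(axis=1)
counts = np.clip(phrase_mask.sum(axis=1, keepdims=True), 1.0, None)
masked_mean = masked_sum / counts
```

For sentence 1 the masked mean is [1, 3], the logits of its only real phrase, while the naive mean is pulled toward the padded values.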

XLNet has <cls> at the last position

phrase_level_activations = phrase_level_activations - self.activation(hidden_state[:,0,:].unsqueeze(1))

I am a bit confused by this line. XLNet places the CLS token last, so the first position of the hidden state would capture the first token rather than CLS.
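
One way to handle both model families is to pick the CLS position from the model name. This is a sketch, not the repository's fix: `cls_index` is a hypothetical helper, and it assumes left padding for XLNet (the Hugging Face XLNet tokenizer's default), so that the last position really is `<cls>`.

```python
import numpy as np

def cls_index(model_name):
    """Position of the classification token by model family (hypothetical helper)."""
    # XLNet formats inputs as `tokens <sep> <cls>`, so <cls> is last when
    # padding is on the left; BERT/RoBERTa-style models put it first.
    return -1 if model_name.startswith("xlnet") else 0

# Hypothetical hidden states: (batch=1, tokens=3, dim=2).
hidden_state = np.array([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]])

roberta_cls = hidden_state[:, cls_index("roberta-base"), :]      # first position
xlnet_cls = hidden_state[:, cls_index("xlnet-base-cased"), :]    # last position
```

With right padding, index -1 would point at a padding token instead, so the padding side needs to be checked before relying on this.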
