
attattr's Introduction

Self-Attention Attribution

This repository contains the implementation for the AAAI 2021 paper Self-Attention Attribution: Interpreting Information Interactions Inside Transformer. It includes the code for generating self-attention attribution scores, pruning attention heads with our method, constructing attribution trees, and extracting adversarial triggers. All of our experiments are conducted on the bert-base-cased model; our methods can also be easily transferred to other Transformer-based models.

Requirements

  • Python version >= 3.5
  • PyTorch version == 1.1.0
  • networkx == 2.3

We recommend running the code with Docker under Linux:

docker run -it --rm --runtime=nvidia --ipc=host --privileged pytorch/pytorch:1.1.0-cuda10.0-cudnn7.5-devel bash

Then install the following packages with pip:

pip install --user networkx==2.3
pip install --user matplotlib==3.1.0
pip install --user tensorboardX six numpy tqdm scikit-learn

You can install attattr from source:

git clone https://github.com/YRdddream/attattr
cd attattr
pip install --user --editable .

Download Pre-Finetuned Models and Datasets

Before running self-attention attribution, you can first fine-tune the bert-base-cased model on a downstream task (such as MNLI) by running run_classifier_orig.py. We also provide example datasets and pre-finetuned checkpoints at Google Drive.

Get Self-Attention Attribution Scores

Run the following command to get the self-attention attribution scores and the raw self-attention scores.

python examples/generate_attrscore.py --task_name ${task_name} --data_dir ${data_dir} \
       --bert_model bert-base-cased --batch_size 16 --num_batch 4 \
       --model_file ${model_file} --example_index ${example_index} \
       --get_att_attr --get_att_score --output_dir ${output_dir}
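
This implements the attribution definition from the paper, Attr_h(A) = A_h ⊙ ∫₀¹ ∂F(αA)/∂A_h dα, approximated by a Riemann sum. Below is a minimal PyTorch sketch of that computation, assuming a model_forward callable that maps a scaled attention matrix to the output probability of the gold label; the names and interface are illustrative assumptions, not the repository's exact API.

import torch

def attention_attribution(model_forward, attention, m=20):
    # Riemann approximation of Attr(A) = A * integral_0^1 dF(aA)/dA da:
    # evaluate the gradient at (k/m) * A for k = 1..m, average the
    # gradients, and multiply elementwise with A.
    grad_sum = torch.zeros_like(attention)
    for k in range(1, m + 1):
        scaled = (attention.detach() * (k / m)).requires_grad_(True)
        prob = model_forward(scaled)                  # scalar F(x; (k/m) * A)
        grad_sum += torch.autograd.grad(prob, scaled)[0]
    return attention * grad_sum / m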

Construction of Attribution Tree

Once you have the self-attention attribution scores of a target example, you can construct its attribution tree. We recommend running get_tokens_and_pred.py to summarize the data; alternatively, you can manually change the value of tokens in attribution_tree.py.

python examples/attribution_tree.py --attr_file ${attr_file} --tokens_file ${tokens_file} \
       --task_name ${task_name} --example_index ${example_index} 

You can generate the attribution tree from the provided example.

python examples/attribution_tree.py --attr_file ${model_and_data}/mnli_example/attr_zero_base_exp16.json \
       --tokens_file ${model_and_data}/mnli_example/tokens_and_pred_100.json \
       --task_name mnli --example_index 16
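
For intuition, here is a heavily simplified sketch of turning per-layer attribution matrices into a graph with networkx: keep an edge whenever its attribution exceeds a fixed fraction of the layer maximum. This only illustrates the thresholding idea; attribution_tree.py builds the tree with additional heuristics, and the matrix layout and threshold below are assumptions.

import networkx as nx

def build_attribution_graph(attr_layers, tokens, threshold=0.4):
    # attr_layers: one (seq_len x seq_len) attribution matrix per layer
    # (numpy arrays); entry [i, j] scores the edge token j -> token i
    g = nx.DiGraph()
    for layer_attr in attr_layers:
        max_attr = abs(layer_attr[1:, :].max())   # row 0 ([CLS]) excluded
        for i in range(1, len(tokens)):
            for j in range(len(tokens)):
                if layer_attr[i, j] / max_attr > threshold:
                    g.add_edge((j, tokens[j]), (i, tokens[i]))
    return g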

Self-Attention Head Pruning

We provide code for pruning attention heads with both our attribution method and the Taylor expansion method. To prune heads with our attribution method:

python examples/prune_head_with_attr.py --task_name ${task_name} --data_dir ${data_dir} \
       --bert_model bert-base-cased --model_file ${model_file}  --output_dir ${output_dir}

To prune heads with the Taylor expansion method:

python examples/prune_head_with_taylor.py --task_name ${task_name} --data_dir ${data_dir} \
       --bert_model bert-base-cased --model_file ${model_file}  --output_dir ${output_dir}
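
Both scripts rank heads by an importance score and prune the least important ones. For the attribution method, the paper takes the maximum attribution value inside each head, averaged over examples; below is a minimal sketch of that metric and a simple pruning mask (the tensor layout and keep_ratio are illustrative assumptions, not the repository's exact code).

import torch

def head_importance(attr_scores):
    # attr_scores: (num_examples, num_layers, num_heads, seq_len, seq_len)
    # max attribution inside each head, then average over examples
    return attr_scores.flatten(3).max(dim=-1).values.mean(dim=0)

def head_mask(importance, keep_ratio=0.5):
    # keep the top keep_ratio fraction of heads, mask out the rest
    flat = importance.flatten()
    k = max(1, int(flat.numel() * keep_ratio))
    cutoff = flat.topk(k).values.min()
    return importance >= cutoff   # (num_layers, num_heads) boolean mask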

Adversarial Attack

First extract the most important connections from the training dataset.

python examples/run_adver_connection.py --task_name ${task_name} --data_dir ${data_dir} \
       --bert_model bert-base-cased --batch_size 16 --num_batch 4 --zero_baseline \
       --model_file ${model_file} --output_dir ${output_dir}

Then use these adversarial triggers to attack the original model.

python examples/run_adver_evaluate.py --task_name ${task_name} --data_dir ${data_dir} \
       --bert_model bert-base-cased --model_file ${model_file} \
       --output_dir ${output_dir} --pattern_file ${pattern_file}
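
Conceptually, the trigger extraction step ranks token-to-token attention connections by their attribution scores over the training set and keeps the highest-scoring patterns. A toy sketch of the ranking for a single attribution matrix (the layout and names are assumptions, not the repository's exact code):

import torch

def top_connections(attr, tokens, top_k=5):
    # attr: (seq_len x seq_len) attribution matrix for one head, where
    # entry [i, j] scores the connection from token j to token i
    n = attr.size(-1)
    values, indices = attr.flatten().topk(top_k)
    return [(tokens[idx % n], tokens[idx // n], val)
            for idx, val in zip(indices.tolist(), values.tolist())]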

Reference

If you find this repository useful for your work, please cite the paper:

@inproceedings{attattr,
  author    = {Yaru Hao and Li Dong and Furu Wei and Ke Xu},
  title     = {Self-Attention Attribution: Interpreting Information Interactions Inside Transformer},
  booktitle = {The Thirty-Fifth {AAAI} Conference on Artificial Intelligence},
  publisher = {{AAAI} Press},
  year      = {2021},
  url       = {https://arxiv.org/pdf/2004.11207.pdf}
}


attattr's Issues

Effectiveness Analysis Code

Hello,
It seems that there is no corresponding code for the Effectiveness Analysis in Section 4.1. Would you be able to add it? I also have a question about Section 4.1: the selection principle for $I_h$ is not clear in the paper. How do you calculate the importance of a head based on the attribution scores (max pooling or average pooling)?


I can't understand the Riemann approximation computation in this paper.

When calculating the Riemann approximation, the paper mentions that the result is best when m = 20, but in the scaled_input function in the code, batch_size = 16 and num_batch = 4, which I don't quite understand. Also, where is the code that multiplies the gradients with the attention weights to obtain the attribution score? As a novice, I still have some questions and hope you can help me solve them. Thank you very much!
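
For reference, my current reading is that the total number of Riemann steps is m = batch_size × num_batch, roughly along these lines (a guessed sketch of what scaled_input might do, not the actual repository code):

import torch

def scaled_input(attention, batch_size, num_batch):
    # m = batch_size * num_batch points (k/m) * A on the straight path
    # from the zero baseline up to A, later consumed in num_batch chunks
    m = batch_size * num_batch
    step = attention / m
    scaled = torch.cat([step * (k + 1) for k in range(m)], dim=0)
    return scaled, step

# The attribution would then be the elementwise product of the step with
# the summed gradients:  attr = step * grads.sum(dim=0)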

Batch-wise attribution

Amazing work! I had a quick question:
As of right now, the method works on one example input at a time. Can we process more than one example at a time using generate_attrscore.py? If so, how do I go about doing that? Thank you!

attention tree threshold

Hi, great work. I was trying to understand the code for the attribution tree construction.
When calculating $\max(Attr(A^t))$, why do you leave out the values of the 0th row in the following section?

import copy

# normalize each layer's attribution by its maximum, excluding row 0
proportion_all = copy.deepcopy(att_all)
for i in range(len(proportion_all)):
    proportion_all[i] /= abs(proportion_all[i][1:, :].max())

Why not use the full matrix to calculate the max? Is this part of the heuristic of removing the interactions with the [CLS] token?

Applying to time series

Hi, I would like to ask whether the method in your example can be simply modified to apply to time series instead of NLP. I see that you use BERT as the baseline, but applying it to time series would require a different form of Transformer.
Thank you so much!

Getting attribution scores for unlabelled data.

Hi,
Regarding the attribution score calculation in generate_attrscore.py: suppose we want to get the scores using, say, an SST-2 fine-tuned model, and we generate the attribution score for an input text without gold labels, i.e. under the assumption that $\text{predicted label} = \text{true label}$. What would the implications of this be? Would it affect the attribution tree generation so much as to render this approach unusable?

Referring to the file generate_attrscore.py:

tar_prob, grad = model(input_ids, segment_ids, input_mask, label_ids, tar_layer, one_batch_att, pred_label=pred_label)

To enforce $\text{predicted label} = \text{true label}$, I changed it to:

pred_tensor = torch.argmax(baseline_logits)  # use the model's own prediction as the label
tar_prob, grad = model(input_ids, segment_ids, input_mask, pred_tensor.unsqueeze(0), tar_layer, one_batch_att, pred_label=pred_label)

Interval parameter m in paper

Hi,

I am catching up with this repo together with your paper. Thanks for providing this great attribution methodology.
On pages 3-4 of the paper, I found a description of the approximation step parameter m, but I haven't found a way to tune it yet.

Could you point out the line of code where you set it, or guide me a bit so that I can implement it on my own?

Thanks!

attention tree error

Hi YRddream,
I am running your code with your examples and commands, but I found an error when running the attribution tree module. With the following command, a blank graph is generated, along with an error that the divisor cannot be 0. Could you help me?
Thank you so much~~

!python examples/attribution_tree.py --attr_file "/content/drive/MyDrive/AttentionAttr/attattr/data/model_and_data/model_and_data/mnli_example/attr_zero_base_exp16.json" \
       --tokens_file "/content/drive/MyDrive/AttentionAttr/attattr/data/model_and_data/model_and_data/mnli_example/tokens_and_pred_100.json" \
       --task_name mnli --example_index 16

IndexError: list index out of range

Thanks for your great code! I enjoyed reading the paper.

When reproducing results following the provided README,

python examples/attribution_tree.py --attr_file ${model_and_data}/mnli_example/attr_zero_base_exp16.json \
       --tokens_file ${model_and_data}/mnli_example/tokens_and_pred_100.json \
       --task_name mnli --example_index 16

I encounter this error:

Traceback (most recent call last):
  File "examples/attribution_tree.py", line 145, in <module>
    main()
  File "examples/attribution_tree.py", line 98, in main
    with open(args.tokens_file) as fin:
IndexError: list index out of range

I found that the provided attribution JSON is a 22×22 matrix, but the token list at index 16 in tokens_and_pred_100.json has only 21 tokens:

['[CLS]', 'the', 'emotions', 'are', 'raw', 'and', 'will', 'strike', 'a', 'nerve', 'with', 'anyone', 'who', "'", 's', 'ever', 'had', 'family', 'trauma', '.', '[SEP]']

This mismatch causes the error.

Could you check the provided files or code? Or did I miss something?
