Removing Bias in Multi-modal Classifiers: Regularization by Maximizing Functional Entropies

Dependencies

The only dependency is PyTorch. We tested it with PyTorch 1.4 and 1.5, but it should work with all PyTorch versions.

Adding our regularization to multi-modal problems

Typically, a multi-modal training procedure looks like this:

import torch 


for image, question, label in loader:
    logits = model(image, question)
    loss = compute_some_loss(logits, label)

    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()

To use our regularization, change the training procedure to:

import torch
from regularization import Perturbation, Regularization, RegParameters


reg_params = RegParameters()

for image, question, label in loader:
    logits = model(image, question)
    loss = compute_some_loss(logits, label)

    ####################### Our regularization method #######################

    expanded_logits = Perturbation.get_expanded_logits(logits, reg_params.n_samples)

    inf_image = Perturbation.perturb_tensor(image, reg_params.n_samples)
    inf_question = Perturbation.perturb_tensor(question, reg_params.n_samples)

    inf_output = model(inf_image, inf_question)
    inf_loss = torch.nn.functional.binary_cross_entropy_with_logits(inf_output, expanded_logits)

    gradients = torch.autograd.grad(inf_loss, [inf_image, inf_question], create_graph=True)
    grads = [Regularization.get_batch_norm(gradients[k], loss=inf_loss,
                                           estimation=reg_params.estimation) for k in range(2)]

    inf_scores = torch.stack(grads)
    reg_term = Regularization.get_regularization_term(inf_scores, norm=reg_params.norm,
                                                      optim_method=reg_params.optim_method)

    loss += reg_params.delta * reg_term

    #########################################################################

    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()
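
The RegParameters() object created above carries the regularization hyper-parameters read inside the loop (delta, n_samples, norm, estimation, optim_method). Below is a minimal sketch of tuning them, assuming they are exposed as plain attributes; check regularization.py for the exact field names and defaults.

reg_params = RegParameters()
reg_params.delta = 1e-10             # weight of the regularization term (example value; the VQA-CP use-case uses --inf_lambda 1e-10)
reg_params.n_samples = 3             # number of perturbed samples used to estimate the expectation
reg_params.norm = 2                  # norm used to combine the per-modality scores
reg_params.estimation = 'ent'        # entropy-based ('ent') or variance-based estimation
reg_params.optim_method = 'max_ent'  # optimization method presented in the paper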

Note that delta is a scalar controlled by the user. Also note that PyTorch does not allow computing gradients for Long tensors. If the input to your model is a Long tensor (which can happen for text represented by tokens), we recommend registering a forward hook on the first embedding layer's output and computing the information scores for those tensors instead.
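
A minimal sketch of that hook-based workaround is shown below. It assumes the text branch of your model starts with an nn.Embedding stored at model.embedding, and that expanded_question_tokens is the Long question tensor repeated n_samples times; both names are hypothetical, so adapt them to your architecture. The other variables are reused from the loop above.

embedding_outputs = []

def save_embedding_output(module, inputs, output):
    # Store the float embedding so gradients can be taken w.r.t. it later.
    embedding_outputs.append(output)

hook_handle = model.embedding.register_forward_hook(save_embedding_output)

# Inside the training loop, after expanding the inputs to n_samples copies:
embedding_outputs.clear()
inf_output = model(inf_image, expanded_question_tokens)  # Long token tensor, only repeated, not perturbed
inf_loss = torch.nn.functional.binary_cross_entropy_with_logits(inf_output, expanded_logits)

# Differentiate w.r.t. the hooked (float) embedding output instead of the Long input.
gradients = torch.autograd.grad(inf_loss, [inf_image, embedding_outputs[-1]], create_graph=True)

hook_handle.remove()  # detach the hook once it is no longer needed

The train.py in the VQA-CP use-case below follows a similar pattern, taking gradients with respect to hooked q_net and v_net outputs rather than the raw inputs.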

VQA-CPv2 SOTA Use-case

We attach a use-case showing how to add our regularization to a given model. We add our regularization term to the original git repository of the paper "Don't Take the Easy Way Out: Ensemble Based Methods for Avoiding Known Dataset Biases." That repository uses a fork of the bottom-up attention repository.

All the code is under the vqa-cp folder.

Prerequisites

To install requirements:

pip install -r requirments.txt

Data Setup

All data should be downloaded to a 'data/' directory in the root directory of this repository.

The easiest way to download the data is to run the provided script tools/download.sh from the repository root. The features are provided by, and downloaded from, the original authors' repo. If the script does not work, examine it and adapt the steps it outlines to your needs. Then run tools/process.sh from the repository root to process the data into the correct format.

Training

We introduce the following new parameters:

  1. lambda (float; default 0.0) - scalar weight of the regularization term.
  2. norm (int; default 2) - which norm to use.
  3. estimation (str; default 'ent') - whether the regularization term is entropy-based or variance-based.
  4. optim_method (str; default 'max_ent') - which optimization method to use. In the paper we present only 'max_ent'.
  5. n_samples (int; default 3) - the number of samples used to estimate the expectation.
  6. grad (bool; default True) - whether or not to use the gradient bound.

Run the following command to train the model with our proposed regularization:

python main.py --output saved_models --seed 0 --cache_features --eval_each_epoch --inf_lambda 1e-10

Testing

The scores reported by the script are very close (within a hundredth of a percent in my experience) to the results reported by the official evaluation metric, but can be slightly different because the answer normalization process of the official script is not fully accounted for. To get the official numbers, run python save_predictions.py /path/to/model /path/to/output_file and then run the official VQA 2.0 evaluation on the resulting file. The official evaluation script is available under the eval folder.

Pre-trained model

Link to pre-trained model: https://gofile.io/d/FbLhKD.

Results

Comparison between our method and the previous state-of-the-art:

Method              Overall  Yes/No  Number  Other
Learned-Mixin +H    52.013   72.580  31.117  46.968
Ours                54.55    74.03   49.16   45.82

Issues

Mismatch between README.md and train.py

Hi, @itaigat, thanks for your great work.
I am confused by the mismatch between README.md and train.py.

In train.py, it is

            if inf.lambda_ > 0:
                with torch.backends.cudnn.flags(enabled=False):
                    inf_pred = Perturbation.get_expanded_logits(pred, inf.n_samples)

                    a_inf = Perturbation.perturb_tensor(a, inf.n_samples, perturbation=False)
                    b_inf = Perturbation.perturb_tensor(b, inf.n_samples, perturbation=False)

                    inf_logits, _ = model(v, None, q, a_inf, b_inf, inf=True)

                    influence_loss = nn.functional.binary_cross_entropy_with_logits(inf_pred, inf_logits)
                    gradients = torch.autograd.grad(influence_loss, [hooks['q_net'][0], hooks['v_net'][0]],
                                                    create_graph=True)

                    grads = [Regularization.get_batch_norm(grad=gradient, loss=influence_loss) for gradient in gradients]

                    inf_scores = torch.stack(grads)
                    reg = Regularization.get_regularization_term(inf_scores, inf.norm, inf.optim_method)
                    loss += inf.lambda_ * reg

where the image v and the question q are not perturbed, but the answer a and the bias b are passed to perturb_tensor with perturbation=False?

In README.md, it is

   expanded_logits = Perturbation.get_expanded_logits(logits, reg_params.n_samples)

   inf_image = Perturbation.perturb_tensor(image, reg_params.n_samples)
   inf_question = Perturbation.perturb_tensor(question, reg_params.n_samples)

   inf_output = model(inf_image, inf_question)
   inf_loss = torch.nn.functional.binary_cross_entropy_with_logits(inf_output, expanded_logits)
   
   gradients = torch.autograd.grad(inf_loss, [inf_image, inf_question], create_graph=True)
   grads = [Regularization.get_batch_norm(gradients[k], loss=inf_loss,
                                          estimation=reg_params.estimation) for k in range(2)]

   inf_scores = torch.stack(grads)
   reg_term = Regularization.get_regularization_term(inf_scores, norm=reg_params.norm,
                                                     optim_method=reg_params.optim_method)
   
   loss += reg_params.delta * reg_term

Getting 52.01 on VQA-CP

Hi, @itaigat, thanks for your great work. I tried to reproduce the reported 54.55 on VQA-CP but got 52.01, using the following command:
python main.py --output saved_models --seed 0 --cache_features --eval_each_epoch --inf_lambda 4e-17
4e-17 is from https://papers.nips.cc/paper/2020/file/20d749bc05f47d2bd3026ce457dcfd8e-Supplemental.pdf

Did I miss any hyper-parameter, or is anything else wrong?
Thanks again!

epoch 1, time: 313.99
	train_loss: inf, score: 26.21
	eval score: 40.08 (90.22)
epoch 2, time: 332.40
	train_loss: 3.70, score: 40.59
	eval score: 47.00 (90.22)
epoch 3, time: 332.09
	train_loss: 3.36, score: 43.40
	eval score: 49.11 (90.22)
epoch 4, time: 307.44
	train_loss: 3.16, score: 45.35
	eval score: 50.93 (90.22)
epoch 5, time: 347.27
	train_loss: 3.02, score: 47.94
	eval score: 51.68 (90.22)
epoch 6, time: 355.29
	train_loss: 2.91, score: 50.06
	eval score: 52.35 (90.22)
epoch 7, time: 395.44
	train_loss: 2.81, score: 51.89
	eval score: 53.25 (90.22)
epoch 8, time: 403.33
	train_loss: 2.73, score: 53.98
	eval score: 53.24 (90.22)
epoch 9, time: 399.54
	train_loss: 2.65, score: 55.95
	eval score: 52.60 (90.22)
epoch 10, time: 400.44
	train_loss: 2.58, score: 57.57
	eval score: 53.38 (90.22)
epoch 11, time: 398.15
	train_loss: 2.52, score: 59.31
	eval score: 53.18 (90.22)
epoch 12, time: 394.01
	train_loss: 2.46, score: 61.08
	eval score: 51.72 (90.22)
epoch 13, time: 388.91
	train_loss: 2.40, score: 62.33
	eval score: 51.76 (90.22)
epoch 14, time: 385.31
	train_loss: 2.35, score: 64.18
	eval score: 51.76 (90.22)
epoch 15, time: 382.76
	train_loss: 2.30, score: 65.47
	eval score: 52.01 (90.22)

Nan loss error

Hi, thanks for your interesting work. I am trying to train the model but got ValueError("Nan loss"). How can I fix the problem?
