
IBA: Informational Bottlenecks for Attribution

[Paper Arxiv] | [Paper Code] | [Reviews] | [API Documentation] | [Examples] | [Installation]


Figure: Iterations of the Per-Sample Bottleneck (animated GIF)

This repository contains an easy-to-use implementation of the IBA attribution method. Our method minimizes the amount of transmitted information while retaining a high classifier score for the explained class. In our paper, we run this optimization per single sample (Per-Sample Bottleneck) and train a neural network to predict the relevant areas (Readout Bottleneck). See our paper for an in-depth description: "Restricting the Flow: Information Bottlenecks for Attribution".
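
The core of the optimization can be illustrated with a minimal NumPy sketch. It follows the paper's formulation as we read it and is not the library's code: a fraction λ = sigmoid(α) of each feature in R is kept and the rest is replaced by noise ε ~ N(μ_R, σ_R²), i.e. Z = λR + (1 − λ)ε; the information term is the KL divergence between Q(Z|R) and N(μ_R, σ_R²), and the objective adds β times this term to the classifier loss. The names `bottleneck`, `capacity` and `total_loss` are hypothetical, and the per-sample gradient descent on α is omitted.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bottleneck(r, alpha, mu, std, rng):
    """Keep a fraction lam = sigmoid(alpha) of the features r and fill the rest with noise.

    The noise is drawn from N(mu, std^2), the estimated feature distribution,
    so lam = 1 passes r unchanged and lam = 0 replaces it entirely.
    """
    lam = sigmoid(alpha)
    eps = rng.normal(mu, std, size=r.shape)
    return lam * r + (1.0 - lam) * eps


def capacity(r, alpha, mu, std):
    """Per-element KL(Q(Z|R) || N(mu, std^2)) in nats, computed in standardized units."""
    lam = sigmoid(alpha)
    r_norm = (r - mu) / std                       # standardize the features
    z_mean = lam * r_norm                         # mean of Q(Z|R)
    z_var = np.maximum((1.0 - lam) ** 2, 1e-12)   # variance of Q(Z|R); clamp avoids log(0)
    return 0.5 * (z_var + z_mean ** 2 - 1.0) - 0.5 * np.log(z_var)


def total_loss(classifier_loss, cap, beta):
    """Classifier loss plus beta times the capacity averaged over the feature map."""
    return classifier_loss + beta * cap.mean()


# Usage with a hypothetical feature map of shape (channels, height, width)
rng = np.random.default_rng(0)
r = rng.normal(2.0, 3.0, size=(8, 4, 4))
alpha = np.zeros_like(r)                          # lam = 0.5 everywhere
z = bottleneck(r, alpha, mu=2.0, std=3.0, rng=rng)
cap = capacity(r, alpha, mu=2.0, std=3.0)
print(z.shape, total_loss(1.2, cap, beta=10.0))
```

Because the KL divergence is unchanged by an affine change of variables, computing it on standardized features gives the same value as in the original units. A larger β pushes more of λ toward 0, so only the features the classifier really needs keep their information.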

Generally, we advise using the Per-Sample Bottleneck over the Readout Bottleneck. We found it to perform better, and it is more flexible, as it only requires estimating the mean and variance of the feature map. The Readout Bottleneck has the advantage of producing attribution maps with a single forward pass once trained.
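
Estimating these statistics does not require holding all feature maps in memory. The following is a minimal sketch, assuming per-element mean and standard deviation accumulated batch by batch with a Welford-style update (Chan et al.'s formula for merging batches); the class name `RunningMeanStd` is hypothetical and not part of the library.

```python
import numpy as np


class RunningMeanStd:
    """Streaming per-element mean and standard deviation over batches of feature maps."""

    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)                 # sum of squared deviations from the mean

    def update(self, batch):
        """Merge a batch with shape (batch_size, *shape) into the running statistics."""
        b_n = batch.shape[0]
        b_mean = batch.mean(axis=0)
        b_m2 = ((batch - b_mean) ** 2).sum(axis=0)
        delta = b_mean - self.mean
        total = self.n + b_n
        self.mean = self.mean + delta * b_n / total
        self.m2 = self.m2 + b_m2 + delta ** 2 * self.n * b_n / total
        self.n = total

    def std(self):
        """Population standard deviation (same as numpy's default ddof=0)."""
        return np.sqrt(self.m2 / self.n)


# Usage with hypothetical feature maps of shape (channels, height, width)
rng = np.random.default_rng(0)
data = rng.normal(1.0, 2.0, size=(1000, 3, 4, 4))
est = RunningMeanStd(shape=(3, 4, 4))
for start in range(0, len(data), 64):             # last batch is smaller (40 samples)
    est.update(data[start:start + 64])
print(np.allclose(est.mean, data.mean(axis=0)), np.allclose(est.std(), data.std(axis=0)))
```

Merging whole batches instead of single samples keeps the update vectorized while staying numerically more stable than accumulating raw sums of squares.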

For the code to reproduce our paper, see IBA-paper-code.

This library provides a TensorFlow v1 and a PyTorch implementation.

PyTorch

Exemplary usage for the Per-Sample Bottleneck:

from IBA.pytorch import IBA, tensor_to_np_img, get_imagenet_folder, imagenet_transform
from IBA.utils import plot_saliency_map, to_unit_interval, load_monkeys

from torch.utils.data import DataLoader
from torchvision.models import vgg16
import torch

imagenet_dir = '/path/to/imagenet/validation'  # set to your ImageNet validation folder

# Load model and switch to evaluation mode (disables dropout)
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = vgg16(pretrained=True)
model.to(dev)
model.eval()

# Add a Per-Sample Bottleneck at layer conv4_1
iba = IBA(model.features[17])

# Estimate the mean and variance of the feature map at this layer.
val_set = get_imagenet_folder(imagenet_dir)
val_loader = DataLoader(val_set, batch_size=64, shuffle=True, num_workers=4)
iba.estimate(model, val_loader, n_samples=5000, progbar=True)

# Load Image
monkeys, target = load_monkeys(pil=True)
monkeys_transform = imagenet_transform()(monkeys)

# Closure that returns the loss for one batch: negative log-probability of the target class
model_loss_closure = lambda x: -torch.log_softmax(model(x), dim=1)[:, target].mean()

# Explain class target for the given image
saliency_map = iba.analyze(monkeys_transform.unsqueeze(0).to(dev), model_loss_closure, beta=10)

# Display result
plot_saliency_map(saliency_map, tensor_to_np_img(monkeys_transform))
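
The saliency map is a per-pixel image, while the bottleneck acts on a (channels, height, width) feature map. A minimal sketch of how such a map can be formed, assuming the per-element capacity is summed over channels and reported in bits as in the paper; the helper `capacity_to_heatmap` is hypothetical, and its nearest-neighbour upsampling is a simplification (a smoother interpolation is likely preferable in practice).

```python
import numpy as np


def capacity_to_heatmap(cap, image_hw):
    """Turn per-element capacity (C, H, W) in nats into an image-sized heatmap in bits.

    Sums over channels, converts nats to bits and upsamples by repeating each
    location (nearest neighbour). Image height and width must be integer
    multiples of H and W.
    """
    heat = cap.sum(axis=0) / np.log(2.0)          # (H, W), bits per location
    h, w = heat.shape
    if image_hw[0] % h or image_hw[1] % w:
        raise ValueError("image size must be a multiple of the feature map size")
    heat = np.repeat(heat, image_hw[0] // h, axis=0)
    return np.repeat(heat, image_hw[1] // w, axis=1)


# Hypothetical capacity at VGG16 conv4_1 for a 224x224 input: 512 channels on a 28x28 grid,
# chosen so that every location carries 1 bit in total
cap = np.full((512, 28, 28), np.log(2.0) / 512)
heatmap = capacity_to_heatmap(cap, (224, 224))
print(heatmap.shape, heatmap.max())
```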

We provide notebooks for both the Per-Sample Bottleneck and the Readout Bottleneck.

TensorFlow

from IBA.tensorflow_v1 import IBACopyInnvestigate, model_wo_softmax, get_imagenet_generator
from IBA.utils import load_monkeys, plot_saliency_map
from keras.applications.vgg16 import VGG16, preprocess_input

imagenet_dir = '/path/to/imagenet/validation'  # set to your ImageNet validation folder

# load model & remove the final softmax layer
model_softmax = VGG16(weights='imagenet')
model = model_wo_softmax(model_softmax)

# the bottleneck will be added after layer block4_conv1
feat_layer = model.get_layer(name='block4_conv1')

# add the bottleneck by copying the model
iba = IBACopyInnvestigate(
    model,
    neuron_selection_mode='index',
    feature_name=feat_layer.output.name,
)

# estimate feature mean and std
val_gen = get_imagenet_generator(imagenet_dir)
iba.fit_generator(val_gen, steps_per_epoch=50)

# load image
monkeys, target = load_monkeys()
monkeys_scaled = preprocess_input(monkeys)

# get the saliency map and plot
saliency_map = iba.analyze(monkeys_scaled[None], neuron_selection=target)
plot_saliency_map(saliency_map, img=monkeys)

Table: Overview of the different TensorFlow classes. (Task) type of task (e.g. regression, classification, unsupervised). (Layer) whether you need to add a layer to the explained model. (Copy) whether the class copies the TensorFlow graph.

Class | Task | Layer | Copy | Remarks
IBALayer | Any | yes | no | Recommended
IBACopy | Any | no | yes | Very flexible
IBACopyInnvestigate | Classification | no | yes | Nice API for classification

Documentation

[PyTorch API] | [TensorFlow API]

The API documentation is hosted online (see the links above).

Table: Exemplary Jupyter notebooks

Notebook | Description
pytorch_IBA_per_sample.ipynb | Per-Sample Bottleneck
pytorch_IBA_train_readout.ipynb | Train a Readout Bottleneck
tensorflow_IBALayer_cifar.ipynb | Train a CIFAR model containing an IBALayer
tensorflow_IBACopy_imagenet.ipynb | Explain an ImageNet model
tensorflow_IBACopyInnvestigate_imagenet.ipynb | iNNvestigate API wrapper

Installation

You can install it directly from git:

$ pip install git+https://github.com/BioroboticsLab/IBA

To install the optional dependencies for torch, tensorflow, tensorflow-gpu, or development (dev), use pip's extras syntax:

$ pip install "IBA[torch,dev] @ git+https://github.com/BioroboticsLab/IBA"

For development, you can also clone the repository locally and then install in development mode:

$ git clone https://github.com/BioroboticsLab/IBA
$ cd IBA
$ pip install -e .

Table: Supported versions

Package | From | To
TensorFlow | 1.12.0 | 1.15.0
PyTorch | 1.1.0 | 1.4.0
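
To check whether an environment falls into the tested range, a small script can compare installed versions against this table. This is a minimal sketch, not part of the library: it assumes the pip distribution names `torch` and `tensorflow` (a `tensorflow-gpu` install would need its own entry), treats both bounds as inclusive, and compares full versions, so a later patch release such as 1.15.2 is reported as untested. `SUPPORTED`, `parse_version` and `is_supported` are hypothetical names.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Tested ranges from the supported-versions table, keyed by pip distribution name
SUPPORTED = {
    "tensorflow": ("1.12.0", "1.15.0"),
    "torch": ("1.1.0", "1.4.0"),
}


def parse_version(text):
    """'1.4.0+cu92' -> (1, 4, 0); keeps only the leading digits of each component."""
    release = text.split("+")[0]                  # drop local tags such as +cu92
    nums = []
    for piece in release.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        nums.append(int(match.group()))
    return tuple(nums + [0] * (3 - len(nums)))


def is_supported(package, installed):
    """True if `installed` lies within the tested range (both bounds inclusive)."""
    low, high = SUPPORTED[package]
    return parse_version(low) <= parse_version(installed) <= parse_version(high)


for pkg in SUPPORTED:
    try:
        installed = version(pkg)
        print(pkg, installed, "supported" if is_supported(pkg, installed) else "untested")
    except PackageNotFoundError:
        print(pkg, "not installed")
```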

Reference

If you use this software for a scientific publication, please cite our paper:

@inproceedings{
Schulz2020Restricting,
title={Restricting the Flow: Information Bottlenecks for Attribution},
author={Karl Schulz and Leon Sixt and Federico Tombari and Tim Landgraf},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=S1xWh1rYwB}
}
