INTR: A Simple Interpretable Transformer for Fine-grained Image Classification and Analysis (ICLR 2024)

This repo is the official implementation of INTR: A Simple Interpretable Transformer for Fine-grained Image Classification and Analysis. It currently includes code and models for the interpretation of fine-grained data. We will provide a link to the upcoming ICLR 2024 proceedings for this paper when it becomes available online.

INTR is a novel usage of Transformers to make image classification interpretable. In INTR, we investigate a proactive approach to classification, asking each class to look for itself in an image. We learn class-specific queries (one for each class) as input to the decoder, allowing them to look for their presence in an image via cross-attention. We show that INTR intrinsically encourages each class to attend distinctly; the cross-attention weights thus provide a meaningful interpretation of the model's prediction. Interestingly, via multi-head cross-attention, INTR could learn to localize different attributes of a class, making it particularly suitable for fine-grained classification and analysis.

In INTR, each query in the decoder is responsible for predicting one class, so each query looks for its own class-specific features in the feature map. First, we visualize the feature map, i.e., the value matrix of the Transformer architecture, to see the important parts of the object in the image. Then, to show which features in the value matrix the model attends to, we plot a heatmap of the model's attention. To avoid external interference in the classification, we use a single weight vector shared across all classes, so the attention weights explain the model's prediction.
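
The idea of class-specific queries scored by a shared weight vector can be sketched in a few lines of plain Python. This is only an illustrative, single-head, single-layer simplification and not the code in this repository: the function name `classify_with_class_queries`, the helper names, and all toy numbers are assumptions made for the example.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def classify_with_class_queries(queries, keys, values, shared_w):
    """Single-head cross-attention with one query per class (hypothetical helper).

    queries:  one d-dimensional vector per class
    keys:     one d-dimensional vector per image patch
    values:   one d-dimensional vector per image patch
    shared_w: a single d-dimensional weight vector shared by all classes
    Returns (attention, logits, predicted_class).
    """
    d = len(shared_w)
    scale = math.sqrt(d)
    attention, logits = [], []
    for q in queries:
        # Where does this class "look" in the image?
        weights = softmax([dot(q, k) / scale for k in keys])
        # Attention-weighted sum of the patch values.
        pooled = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(d)]
        attention.append(weights)
        # The same weight vector scores every class, so differences between
        # classes can only come from what each query attended to.
        logits.append(dot(shared_w, pooled))
    predicted = max(range(len(logits)), key=logits.__getitem__)
    return attention, logits, predicted


# Toy example: 2 classes, 3 patches, d = 2 (all numbers are made up).
queries = [[4.0, 0.0], [0.0, 4.0]]
keys = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
values = [[1.0, 0.0], [0.2, 0.0], [0.0, 0.0]]
shared_w = [1.0, 1.0]
attention, logits, predicted = classify_with_class_queries(queries, keys, values, shared_w)
print(predicted, [round(w, 2) for w in attention[predicted]])
```

In this sketch, `attention[predicted]` plays the role of the heatmap: it is the distribution over patches that produced the winning logit.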

Fine-tune models and results

Classification performance and fine-tuned checkpoints of INTR with the DETR-R50 backbone on different datasets.

Dataset     acc@1   acc@5   Model
CUB         71.8    89.3    checkpoint download
Bird        97.4    99.2    checkpoint download
Butterfly   95.0    98.3    checkpoint download

Installation Instructions

Create a Python environment (optional)

conda create -n intr python=3.8 -y
conda activate intr

Clone the repository

git clone https://github.com/dipanjyoti/INTR.git
cd INTR

Install Python dependencies

pip install -r requirements.txt

Data Preparation

Organize the data in the following format.

datasets
├── dataset_name
│   ├── train
│   │   ├── class1
│   │   │   ├── img1.jpeg
│   │   │   ├── img2.jpeg
│   │   │   └── ...
│   │   ├── class2
│   │   │   ├── img3.jpeg
│   │   │   └── ...
│   │   └── ...
│   └── val
│       ├── class1
│       │   ├── img4.jpeg
│       │   ├── img5.jpeg
│       │   └── ...
│       ├── class2
│       │   ├── img6.jpeg
│       │   └── ...
│       └── ...
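
Before launching a run, it can help to confirm that a dataset folder matches the layout above. The following is a minimal sketch using only the standard library. The function `check_layout` is hypothetical and not part of this repository, and it assumes that `train` and `val` are expected to contain the same class folders, as the tree suggests.

```python
from pathlib import Path


def check_layout(dataset_dir):
    """Return {split: {class_name: image_count}} for the train/val splits.

    Raises ValueError if a split folder is missing or if the two splits
    do not contain the same class folders (an assumption of this sketch).
    """
    dataset_dir = Path(dataset_dir)
    summary = {}
    for split in ("train", "val"):
        split_dir = dataset_dir / split
        if not split_dir.is_dir():
            raise ValueError(f"missing split folder: {split_dir}")
        # One sub-folder per class; count the files inside each one.
        summary[split] = {
            class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    mismatch = set(summary["train"]) ^ set(summary["val"])
    if mismatch:
        raise ValueError(f"class folders differ between splits: {sorted(mismatch)}")
    return summary
```

For example, `check_layout("datasets/dataset_name")` returns the per-class image counts for both splits, or raises an error that names the class folders present in only one split.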

INTR Evaluation

To evaluate INTR on the CUB dataset in a multi-GPU setting (e.g., 4 GPUs), run the command below. INTR checkpoints are available under Fine-tune models and results.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --eval --resume <path/to/intr_checkpoint_cub_detr_r50.pth> --dataset_path <path/to/datasets> --dataset_name <dataset_name> 

INTR Interpretation

To visualize INTR's interpretations, run the command below. It shows the interpretation for the class with index <class_number>. By default, it displays interpretations from all attention heads. To also show the interpretations associated with the top queries (labeled top_q), set the parameter sim_query_heads to 1. Use a batch size of 1 for visualization.

python -m tools.visualization --eval --resume <path/to/intr_checkpoint_cub_detr_r50.pth> --dataset_path <path/to/datasets> --dataset_name <dataset_name> --class_index <class_number>

Single-image prediction and visualization at inference time: we also provide a Jupyter notebook, demo.ipynb, for single-image prediction and visualization. Note that the demo targets the CUB dataset.

INTR Training

To train INTR, start from the pretrained DETR-R50 model. To train on a particular dataset, set '--num_queries' to the number of classes in the dataset. In the INTR architecture, each decoder query is assigned to capture class-specific features, so every query is learned during training. Consequently, the number of model parameters grows with the number of classes in the dataset. To train INTR on a multi-GPU system (e.g., 4 GPUs), run the command below.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --finetune <path/to/detr-r50-e632da11.pth> --dataset_path <path/to/datasets> --dataset_name <dataset_name> --num_queries <num_of_classes>
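
The value for '--num_queries' can be read off the training folder, and the effect of the class count on model size can be estimated with simple arithmetic. The sketch below is illustrative only: both helper functions are hypothetical and not part of this repository, and `hidden_dim=256` is DETR's default transformer width, assumed here to carry over to INTR.

```python
from pathlib import Path


def num_queries_for(dataset_path, dataset_name):
    """Count the class folders under <dataset_path>/<dataset_name>/train."""
    train_dir = Path(dataset_path) / dataset_name / "train"
    return sum(1 for entry in train_dir.iterdir() if entry.is_dir())


def query_embedding_params(num_classes, hidden_dim=256):
    """Learnable scalars in the class queries: one hidden_dim vector per class.

    hidden_dim=256 is an assumption (DETR's default transformer width).
    """
    return num_classes * hidden_dim


# With 200 classes (as in CUB), the queries alone hold 200 * 256 = 51200 scalars.
print(query_embedding_params(200))
```

Under these assumptions, only the query embeddings scale with the number of classes; the backbone and the rest of the Transformer keep the same size.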

Acknowledgment

Our model is inspired by the DEtection TRansformer (DETR) method.

We thank the authors of DETR for doing such great work.

BibTeX

If you find our work helpful for your research, please consider citing it with the BibTeX entry below.

@inproceedings{paul2024simple,
  title={A Simple Interpretable Transformer for Fine-Grained Image Classification and Analysis},
  author={Paul, Dipanjyoti and Chowdhury, Arpita and Xiong, Xinqi and Chang, Feng-Ju and Carlyn, David and Stevens, Samuel and Provost, Kaiya and Karpatne, Anuj and Carstens, Bryan and Rubenstein, Daniel and Stewart, Charles and Berger-Wolf, Tanya and Su, Yu and Chao, Wei-Lun},
  booktitle={International Conference on Learning Representations},
  year={2024}
}
