LeGrad

Walid Bousselham¹, Angie Boggust², Sofian Chaybouti¹, Hendrik Strobelt³˒⁴ and Hilde Kuehne¹˒³

¹ University of Bonn & Goethe University Frankfurt, ² MIT CSAIL, ³ MIT-IBM Watson AI Lab, ⁴ IBM Research.

Vision-Language foundation models have shown remarkable performance in various zero-shot settings such as image retrieval, classification, or captioning. We propose LeGrad, an explainability method specifically designed for ViTs. With LeGrad, we explore the decision-making process of such models by leveraging their feature formation process. A by-product of understanding VL models' decision-making is the ability to produce a localised heatmap for any text prompt.

The following is the code for a wrapper around the OpenCLIP library to equip VL models with LeGrad.

🔨 Installation

The legrad library can be installed via pip:

$ pip install legrad_torch
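
To confirm the installation before running the examples, a quick check along these lines may help. It is a minimal sketch and not part of the legrad package: `missing_packages` is a hypothetical helper, and the import names `legrad`, `open_clip`, and `torch` are the ones used in the usage snippet of this README.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found for import.

    Hypothetical helper for illustration; it only looks packages up and
    does not actually import them.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names taken from the usage snippet in this README.
missing = missing_packages(["legrad", "open_clip", "torch"])
print("missing packages:", missing or "none")
```

If `legrad` is reported as missing, the pip command above was probably run in a different Python environment than the one executing the check.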

Demo

To run the gradio app locally, first install gradio and then run app.py:

$ pip install gradio
$ python app.py

Usage

To see which pretrained models are available, use the following code snippet:

import legrad
legrad.list_pretrained()

Single Image

To process an image and a text prompt, use the following code snippet:

Note: the wrapper does not affect the original model, hence all the functionalities of OpenCLIP models can be used seamlessly.

import requests
from PIL import Image
import open_clip
import torch

from legrad import LeWrapper, LePreprocess
from legrad.utils import visualize

# ------- model's parameters -------
model_name = 'ViT-B-16'
pretrained = 'laion2b_s34b_b88k'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ------- init model -------
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name=model_name, pretrained=pretrained, device=device)
tokenizer = open_clip.get_tokenizer(model_name=model_name)
model.eval()
# ------- Equip the model with LeGrad -------
model = LeWrapper(model)
# ___ (Optional): Wrapper for Higher-Res input image ___
preprocess = LePreprocess(preprocess=preprocess, image_size=448)

# ------- init inputs: image + text -------
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = preprocess(Image.open(requests.get(url, stream=True).raw)).unsqueeze(0).to(device)
text = tokenizer(['a photo of a cat']).to(device)

# ------- compute the LeGrad heatmap for the text prompt -------
text_embedding = model.encode_text(text, normalize=True)
print(image.shape)
explainability_map = model.compute_legrad_clip(image=image, text_embedding=text_embedding)

# ___ (Optional): Visualize overlay of the image + heatmap ___
visualize(heatmaps=explainability_map, image=image)
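
Beyond the overlay, a heatmap can be reduced to a peak location or a binary mask. The following is a dependency-free sketch on a plain 2D list of floats, not part of the legrad API: `normalize`, `peak_location`, and `binary_mask` are hypothetical helpers, the 0.5 threshold is an arbitrary choice, and obtaining such a list via something like `explainability_map[0, 0].tolist()` assumes the map has shape (batch, prompts, height, width), which should be checked against the actual output.

```python
def normalize(heatmap):
    """Min-max rescale a 2D list of floats to the range [0, 1]."""
    flat = [v for row in heatmap for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        # A constant map carries no localisation signal; return all zeros.
        return [[0.0 for _ in row] for row in heatmap]
    return [[(v - lo) / span for v in row] for row in heatmap]


def peak_location(heatmap):
    """Return (row, col) of the largest value in a 2D list."""
    positions = ((r, c) for r, row in enumerate(heatmap) for c in range(len(row)))
    return max(positions, key=lambda rc: heatmap[rc[0]][rc[1]])


def binary_mask(heatmap, threshold=0.5):
    """Mark pixels whose normalized value reaches the threshold (assumed 0.5)."""
    return [[1 if v >= threshold else 0 for v in row] for row in normalize(heatmap)]


# Toy 3x3 heatmap standing in for a real LeGrad output.
toy = [[0.1, 0.2, 0.1],
       [0.2, 0.9, 0.3],
       [0.1, 0.3, 0.1]]
print(peak_location(toy))  # (1, 1)
print(binary_mask(toy))    # only the centre pixel is set
```

On real heatmaps the same steps are more naturally written with tensor operations; the list version is only meant to make the logic explicit.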

โญ Acknowledgement

This code is build as wrapper around OpenCLIP library from LAION, visit their repo for more vision-language models. This project also takes inspiration from Transformer-MM-Explainability and the timm library, please visit their repository.

📚 Citation

If you find this repository useful, please consider citing our work 📝 and giving a star 🌟:

@article{bousselham2024legrad,
  author    = {Bousselham, Walid and Boggust, Angie and Chaybouti, Sofian and Strobelt, Hendrik and Kuehne, Hilde},
  title     = {LeGrad: An Explainability Method for Vision Transformers via Feature Formation Sensitivity},
  journal   = {arXiv preprint arXiv:2404.03214},
  year      = {2024},
}

