

szagoruyko commented on July 21, 2024

We have a Colab in preparation; it should be ready within the next few days. Stay tuned!

from detr.

defqoon commented on July 21, 2024

For anyone interested, you can replicate the figure by registering a forward hook on the multi-head attention module of the last decoder layer (the decoder's cross-attention over the encoder output).

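import torch

model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)

out = {}
def get_output(name):
    # nn.MultiheadAttention returns (attn_output, attn_weights); keep the weights
    def hook(module, input, output):
        out[name] = output[1].detach()
    return hook

# cross-attention of the last (6th) decoder layer
model.transformer.decoder.layers[5].multihead_attn.register_forward_hook(get_output('last_decoder'))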

After a forward pass,

act_decoder = out['last_decoder'].squeeze()  # (num_queries, feat_h * feat_w)

gives you the heatmaps. You still need to reshape and resize them, and to select the map corresponding to each object.
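The hook mechanism can be checked without downloading DETR. This is a minimal sketch assuming only PyTorch: a standalone `nn.MultiheadAttention` stands in for `model.transformer.decoder.layers[5].multihead_attn`, with 100 queries attending over a 15 × 20 feature grid. The sizes (256-dim embeddings, 8 heads) are illustrative, not read from the model.

```python
import torch
import torch.nn as nn

out = {}

def get_output(name):
    # nn.MultiheadAttention returns (attn_output, attn_weights);
    # store only the weights, detached from the autograd graph
    def hook(module, inputs, output):
        out[name] = output[1].detach()
    return hook

# Stand-in for a decoder cross-attention layer (illustrative sizes)
attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
handle = attn.register_forward_hook(get_output('last_decoder'))

num_queries, feat_h, feat_w = 100, 15, 20
queries = torch.randn(num_queries, 1, 256)      # (tgt_len, batch, embed_dim)
memory = torch.randn(feat_h * feat_w, 1, 256)   # (src_len, batch, embed_dim)

with torch.no_grad():
    attn(queries, memory, memory)
handle.remove()  # remove the hook once the maps are captured

weights = out['last_decoder']    # (batch, num_queries, feat_h * feat_w), averaged over heads
act_decoder = weights.squeeze()  # (num_queries, feat_h * feat_w)
print(weights.shape, act_decoder.shape)
```

Each row of `act_decoder` is a softmax distribution over the 300 feature positions, which is why it can be read directly as a heatmap once reshaped to 15 × 20.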


fmassa commented on July 21, 2024

The PR has been merged, so you can access the Colab at
https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb

Let us know if you have further questions!


defqoon commented on July 21, 2024

@lessw2020 I tried to clean up my code a bit:

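import torch
import torchvision.transforms as T
from models.detr import PostProcess  # from the DETR repository

# assumed preprocessing: ImageNet normalization, no resizing
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# forward pass (image is a PIL image loaded beforehand)
d = transforms(image)
model.eval()
with torch.no_grad():
    pred = model(d.unsqueeze(0))

# post processing: target size is the (height, width) of the input
postprocessors = {'bbox': PostProcess()}
_, h, w = d.shape
preds = postprocessors['bbox'](pred, torch.tensor([h, w]).unsqueeze(0))[0]

# find the high-score proposals
obj_ids = torch.where(preds['scores'] > 0.9)[0]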
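To see where the 0.9 threshold applies, here is a simplified re-implementation of the scoring part of DETR-style post-processing. It is a sketch for illustration, not the repository's `PostProcess` (which also rescales the predicted boxes to pixel coordinates). The helper name `postprocess_scores` and the toy logits are hypothetical; the 92 logits per query assume 91 COCO class slots plus one "no object" class.

```python
import torch
import torch.nn.functional as F

def postprocess_scores(pred_logits, threshold=0.9):
    """Hypothetical helper: per-query scores and labels from DETR-style logits.

    pred_logits: (batch, num_queries, num_classes + 1); the last class is
    "no object" and is dropped before taking the max.
    """
    probs = F.softmax(pred_logits, dim=-1)
    scores, labels = probs[..., :-1].max(dim=-1)   # best real class per query
    keep = torch.where(scores[0] > threshold)[0]   # indices for the first image
    return scores, labels, keep

# Toy logits: most queries predict "no object", queries 7 and 42 predict class 22
logits = torch.full((1, 100, 92), -5.0)
logits[..., -1] = 5.0
logits[0, 7, 22] = 10.0
logits[0, 42, 22] = 10.0

scores, labels, obj_ids = postprocess_scores(logits)
print(obj_ids.tolist(), labels[0, obj_ids].tolist())
```

The indices in `obj_ids` are query indices, so they can be used directly to pick rows out of the 100 decoder attention maps.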

Then you can resize the reshaped maps (where maps is your torch.Size([100, 1, 15, 20]) tensor) to the image size, keep only the high-score proposals, and display them:

import torch.nn.functional as F
upsampled_maps = F.interpolate(maps, (h, w), mode='bilinear', align_corners=False)[obj_ids, 0]
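For the display step, one simple option is a row of matplotlib panels: the image first, then one map per selected object. This is a minimal sketch assuming matplotlib is installed, that the image is an (H, W, 3) float array in [0, 1], and that `upsampled_maps` has shape (n_objects, H, W); `show_maps` is a hypothetical helper, and random tensors stand in for real data.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
import torch

def show_maps(image, upsampled_maps):
    """Hypothetical helper: plot the image followed by one attention map per object."""
    n = upsampled_maps.shape[0]
    fig, axes = plt.subplots(1, n + 1, figsize=(4 * (n + 1), 4), squeeze=False)
    axes = axes[0]
    axes[0].imshow(image)
    axes[0].set_title("image")
    for i, attn_map in enumerate(upsampled_maps):
        axes[i + 1].imshow(attn_map.cpu().numpy())
        axes[i + 1].set_title(f"object {i}")
    for ax in axes:
        ax.axis("off")
    return fig

# Random stand-ins for the image and two upsampled maps
image = torch.rand(480, 640, 3).numpy()
upsampled_maps = torch.rand(2, 480, 640)
fig = show_maps(image, upsampled_maps)
```

In a notebook, `plt.show()` (or just leaving `fig` as the cell's last expression) displays the figure.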


lessw2020 commented on July 21, 2024

@szagoruyko - colab sounds fantastic! Thanks very much and look forward to it.


defqoon commented on July 21, 2024

Any update on this? Looking forward to trying the Colab!


lessw2020 commented on July 21, 2024

@thomashossler - thanks for the informative post.
I've registered the hook and gotten the activations, but I'm unclear about the last step of "select the map corresponding to an object" and about how to display the result.
Would you be able to post a full working example to just finish off the process?


defqoon commented on July 21, 2024

@lessw2020 First you need to reshape the attention map. If you are using the elephant image without any resizing, your attention map should have shape torch.Size([1, 100, 300]), i.e. 100 queries attending over 15 × 20 = 300 positions of the backbone feature map. After reshaping, you get a map of shape torch.Size([100, 1, 15, 20]). The first dimension is the number of proposals.

Then, when you run the forward pass, look at the predicted labels. For the elephant image, only 2 of the 100 proposals have a predicted label of 22 (the elephant class id). Use the indices of these two proposals to extract the objects' attention maps, then resize the maps to the image size and you should be done.
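The reshape, select-by-label, and resize steps described above fit in one small function. This is a sketch under stated assumptions: the hypothetical helper `object_attention_maps` takes the hooked weights of shape (1, num_queries, feat_h * feat_w) and a per-query label tensor, and the toy inputs are random. With real predictions it may be safer to also apply a score threshold, since low-confidence queries can share the same label.

```python
import torch
import torch.nn.functional as F

def object_attention_maps(attn_weights, pred_labels, class_id, feat_hw, image_hw):
    """Hypothetical helper: select and upsample decoder attention maps for one class.

    attn_weights: (1, num_queries, feat_h * feat_w) from the forward hook
    pred_labels:  (num_queries,) predicted class per query
    feat_hw:      (feat_h, feat_w) of the backbone feature map
    image_hw:     (H, W) target size for display
    """
    num_queries = attn_weights.shape[1]
    # (1, 100, 300) -> (100, 1, 15, 20): one single-channel map per query
    maps = attn_weights[0].reshape(num_queries, 1, *feat_hw)
    obj_ids = torch.where(pred_labels == class_id)[0]
    # Upsample only the selected maps rather than all 100
    upsampled = F.interpolate(maps[obj_ids], size=image_hw,
                              mode="bilinear", align_corners=False)
    return obj_ids, upsampled[:, 0]  # (n_objects, H, W)

# Toy inputs: 100 queries over a 15 x 20 grid; queries 3 and 58 predict class 22
attn_weights = torch.rand(1, 100, 300)
pred_labels = torch.full((100,), 91, dtype=torch.long)  # 91: arbitrary "other" label
pred_labels[[3, 58]] = 22
obj_ids, maps = object_attention_maps(attn_weights, pred_labels, 22, (15, 20), (480, 640))
print(obj_ids.tolist(), maps.shape)
```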


lessw2020 commented on July 21, 2024

@thomashossler - thanks very much.
You mean look at the output tensor and use its predictions to index back into the attention maps?


fmassa commented on July 21, 2024

Hi,

FYI, I've sent a PR in #112 which adds a Colab notebook for visualizing both the self-attention in the encoder and the inter-attention in the decoder.
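Encoder self-attention is square over feature positions, so a common way to view it is to reshape it into a 4-D grid and pick one reference point. This is a sketch, not the notebook's code: it assumes the weights of shape (1, HW, HW) were captured with a forward hook on an encoder layer's `self_attn` module (in the same way as the decoder hook above), and that the image-to-feature stride is 32. `encoder_attention_for_point` is a hypothetical helper.

```python
import torch

def encoder_attention_for_point(enc_attn, feat_hw, xy, stride=32):
    """Hypothetical helper: self-attention of one feature position as a (feat_h, feat_w) grid.

    enc_attn: (1, feat_h * feat_w, feat_h * feat_w) self-attention weights
    feat_hw:  (feat_h, feat_w) of the backbone feature map
    xy:       (x, y) reference point in input-image pixels
    stride:   downsampling factor between image and feature map (assumed 32)
    """
    feat_h, feat_w = feat_hw
    # (1, HW, HW) -> (h, w, h, w), indexed as [query_y, query_x, key_y, key_x]
    sattn = enc_attn[0].reshape(feat_h, feat_w, feat_h, feat_w)
    x, y = xy
    # How the position containing (x, y) attends to every key position
    return sattn[y // stride, x // stride]

# Toy weights: softmax over random scores for a 15 x 20 feature map
enc_attn = torch.softmax(torch.randn(1, 300, 300), dim=-1)
grid = encoder_attention_for_point(enc_attn, (15, 20), xy=(200, 100))
print(grid.shape)  # torch.Size([15, 20])
```

Indexing the last two dimensions instead (`sattn[..., y, x]`) gives the transposed view: how strongly every position attends to the reference point.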


lessw2020 commented on July 21, 2024

Thanks very much for this colab @fmassa!
(And thank you @thomashossler for posting the additional code!)

Question for @fmassa - What cmap or matplotlib settings did you guys use to make the glow coloration in the paper examples (attached)?
I used the jet and magma colormaps with alpha blending to overlay my heatmaps on the images (from the decoder attention maps, now that I understand how to hook them thanks to all the info posted here), but I like the look of your color-coded glow better :)

[attached image: detr-glow-maps]
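The colormap-plus-alpha-blend approach mentioned above can be sketched as follows. Assumptions: matplotlib 3.5 or later (for the `matplotlib.colormaps` registry), an (H, W, 3) float image in [0, 1], and an attention map already resized to the image size. `overlay_heatmap` is a hypothetical helper.

```python
import numpy as np
import matplotlib

def overlay_heatmap(image, attn_map, cmap_name="jet", alpha=0.5):
    """Hypothetical helper: alpha-blend a colormapped attention map onto an RGB image."""
    # Normalize to [0, 1] so the colormap spans its full range
    lo, hi = attn_map.min(), attn_map.max()
    norm = (attn_map - lo) / (hi - lo + 1e-8)
    heat = matplotlib.colormaps[cmap_name](norm)[..., :3]  # drop the RGBA alpha channel
    return (1 - alpha) * image + alpha * heat

# Toy example: black 4 x 4 image, map increasing left to right, top to bottom
image = np.zeros((4, 4, 3))
attn_map = np.arange(16, dtype=float).reshape(4, 4)
blended = overlay_heatmap(image, attn_map, cmap_name="magma", alpha=0.5)
```

The blended array can be passed straight to `plt.imshow`. A single colormap like this colors every object's map the same way, which is the limitation the per-object palette below addresses.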


fmassa commented on July 21, 2024

@lessw2020 I believe we used seaborn for that (but @szagoruyko can correct me if I'm wrong)

import seaborn as sns
n_objects = 2  # number of detected objects to color, e.g. the two elephants
colors = sns.color_palette(n_colors=n_objects)
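One plausible way to turn such a palette into a color-coded glow is to tint each object's map with its own color and blend the result onto the image where attention is high. This is only a guess at the effect, not confirmed to be how the paper's figures were made. It assumes maps already resized to the image size and colors given as RGB tuples in [0, 1] (which is what `sns.color_palette` returns). `color_coded_glow` is a hypothetical helper.

```python
import numpy as np

def color_coded_glow(image, maps, colors, alpha=0.6):
    """Hypothetical helper: tint each object's map with its own color and blend.

    image:  (H, W, 3) floats in [0, 1]
    maps:   (n_objects, H, W) attention maps resized to the image size
    colors: n_objects RGB tuples in [0, 1], e.g. sns.color_palette(n_colors=n_objects)
    """
    maps = np.asarray(maps, dtype=float)
    # Per-object normalization so every object glows at full strength
    mins = maps.min(axis=(1, 2), keepdims=True)
    maxs = maps.max(axis=(1, 2), keepdims=True)
    maps = (maps - mins) / (maxs - mins + 1e-8)
    colors = np.asarray(colors, dtype=float)[:, None, None, :]  # (n, 1, 1, 3)
    glow = (maps[..., None] * colors).max(axis=0)              # brightest tint wins per pixel
    weight = alpha * maps.max(axis=0)[..., None]               # blend only where attention is high
    return (1 - weight) * image + weight * glow

# Toy example: two objects attending to opposite corners of a black 2 x 2 image
image = np.zeros((2, 2, 3))
maps = np.array([[[1.0, 0.0], [0.0, 0.0]],
                 [[0.0, 0.0], [0.0, 1.0]]])
colors = [(1.0, 0.0, 0.0), (0.0, 0.0, 1.0)]  # stand-in for sns.color_palette(n_colors=2)
result = color_coded_glow(image, maps, colors, alpha=1.0)
```

Taking the per-pixel maximum rather than the sum keeps overlapping objects from saturating to white; a sum would be a reasonable alternative if blended colors are preferred.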

