Comments (12)
we have a colab in preparation, should be ready within the next few days, stay tuned
from detr.
for anyone interested, you can replicate the figure by registering a forward hook on the last multihead attention layer of the decoder.
out = {}

def get_output(name):
    # nn.MultiheadAttention returns (attn_output, attn_weights); keep the weights
    def hook(model, input, output):
        out[name] = output[1].detach()
    return hook
model.transformer.decoder.layers[5].multihead_attn.register_forward_hook(get_output('last_decoder'))
After a forward pass,
act_decoder = out['last_decoder'].squeeze()
will get you the heatmap (you still need to resize it and select the map corresponding to an object, though).
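The hook above depends on `nn.MultiheadAttention` returning a tuple whose second element is the attention weights. Below is a minimal sketch of that behaviour on a toy layer rather than on DETR itself. The embedding size, the number of heads, the `captured` dictionary, and the random inputs are all assumptions made for illustration.

```python
import torch
from torch import nn

# Toy stand-in for a decoder cross-attention layer (sizes are illustrative only).
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4)

captured = {}

def save_attention(name):
    def hook(module, inputs, output):
        # output is (attn_output, attn_weights); store the weights
        captured[name] = output[1].detach()
    return hook

handle = attn.register_forward_hook(save_attention("toy_layer"))

queries = torch.randn(100, 1, 16)  # (num_queries, batch, embed_dim)
memory = torch.randn(300, 1, 16)   # (H*W, batch, embed_dim), e.g. a 15x20 feature map

with torch.no_grad():
    attn(queries, memory, memory)
handle.remove()  # remove the hook once the weights are captured

weights = captured["toy_layer"]
print(weights.shape)  # (batch, num_queries, H*W)
```

Each row of `weights` is a distribution over the flattened feature-map positions for one query, which is why it can be reshaped into a spatial map afterwards.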
from detr.
The PR has been merged, so you can access the colab with
https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_attention.ipynb
Let us know if you have further questions!
from detr.
@lessw2020 I tried cleaning my code a bit:
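import torch
import torch.nn.functional as F
from models.detr import PostProcess  # assumes the DETR repo root is on the Python path

# forward pass (inference only, so no gradients are needed)
d = transforms(image)
model.eval()
with torch.no_grad():
    pred = model(d.unsqueeze(0))

# post processing: convert raw outputs to scores, labels and boxes
postprocessors = {'bbox': PostProcess()}
_, h, w = d.shape
preds = postprocessors['bbox'](pred, torch.tensor([h, w]).unsqueeze(0))[0]

# find the high score proposals
obj_ids = torch.where(preds['scores'] > 0.9)[0]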
Then you can resize the reshaped maps (where maps is your torch.Size([100, 1, 15, 20]) tensor) and display them:
upsampled_maps = F.interpolate(maps, (h, w), mode='bilinear', align_corners=False)[obj_ids, 0]
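Before displaying the upsampled maps, it can help to rescale each one to the range [0, 1] so that different objects are comparable when plotted. The following is only a sketch on random data. The image size, the proposal indices, and the per-map min-max normalisation are assumptions and do not come from the thread.

```python
import torch
import torch.nn.functional as F

maps = torch.rand(100, 1, 15, 20)  # stand-in for the reshaped attention maps
obj_ids = torch.tensor([7, 42])    # hypothetical high-score proposals
h, w = 480, 640                    # hypothetical input image size

# Same call as in the comment above: upsample, then keep the selected proposals
upsampled_maps = F.interpolate(maps, (h, w), mode="bilinear", align_corners=False)[obj_ids, 0]

# Min-max normalise each map independently so it spans [0, 1]
flat = upsampled_maps.flatten(1)
lo = flat.min(dim=1, keepdim=True).values
hi = flat.max(dim=1, keepdim=True).values
normalised = ((flat - lo) / (hi - lo).clamp_min(1e-8)).reshape(upsampled_maps.shape)

print(normalised.shape)
```

Each entry of `normalised` can then be passed to an image-plotting call, or used as an alpha channel over the original image.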
from detr.
@szagoruyko - colab sounds fantastic! Thanks very much and look forward to it.
from detr.
Any update on this? Looking forward to trying the colab!
from detr.
@thomashossler - thanks for the informative post.
I've hooked and gotten the activations but I'm unclear re: the last step of "select the map corresponding to an object" and how to display.
Would you be able to post a full working example to just finish off the process?
from detr.
@lessw2020 so you need to first reshape the attention map. If you are using the elephant image without any resizing, your attention map should have a shape of torch.Size([1, 100, 300]). When reshaping, you get a map of shape torch.Size([100, 1, 15, 20]). The first dimension is the number of proposals.
Now when you forward pass the image, you should look at the label predictions. For the elephant image, only 2 of the 100 proposals have a predicted label of 22 (elephant class id). You can use the indices of these two proposals to extract the attention maps of the objects. You just need to resize them to the image size and you should be done.
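The reshape-and-select step described here can be sketched on random data as follows. The attention tensor and the labels are stand-ins, and the two proposal indices (7 and 42) are made up for illustration. Only the shapes and the class id 22 come from the comment above.

```python
import torch

num_queries, feat_h, feat_w = 100, 15, 20
elephant_class_id = 22  # class id quoted in the comment above

# Random stand-ins for the hooked attention weights and the predicted labels
attention = torch.rand(1, num_queries, feat_h * feat_w)
labels = torch.zeros(num_queries, dtype=torch.long)
labels[[7, 42]] = elephant_class_id  # hypothetical: two proposals predicted as elephants

# (1, 100, 300) -> (100, 1, 15, 20): one single-channel map per proposal
maps = attention.squeeze(0).reshape(num_queries, 1, feat_h, feat_w)

# Indices of the proposals whose predicted label is the class of interest
obj_ids = torch.where(labels == elephant_class_id)[0]
object_maps = maps[obj_ids, 0]

print(obj_ids.tolist(), object_maps.shape)
```

In a real run, `labels` would come from the model's predictions, and `object_maps` would still need to be resized to the image size.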
from detr.
@thomashossler - thanks very much.
You mean look at the output tensor and use the predictions from that to index in reverse back to the attention maps?
from detr.
Hi,
FYI I've sent a PR in #112 which adds a Colab notebook for visualizing both self-attention in the encoder as well as inter-attention in the decoder.
from detr.
Thanks very much for this colab @fmassa!
(And thank you @thomashossler for posting the additional code!)
Question for @fmassa - What cmap or matplotlib settings did you guys use to make the glow coloration in the paper examples (attached)?
I used jet and magma cmap with an alpha blend for putting my heatmaps onto my images (from the attention decoder map, now that I understand how to hook those thanks to all the info posted here), but I like the look of your color coded glow better :)
from detr.
@lessw2020 I believe we used seaborn for that (but @szagoruyko can correct me if I'm wrong)
import seaborn as sns
colors = sns.color_palette(n_colors=n_objects)
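One possible way to turn a per-object color into a glow-style overlay is to use the attention map as the alpha channel of a constant-color image. This is only a guess at a workable approach, not the rendering used for the paper. The `color_overlay` helper is hypothetical, and the hard-coded RGB tuples stand in for entries of the palette above.

```python
import torch

def color_overlay(attn_map, rgb):
    """Return an (H, W, 4) RGBA image: constant color, alpha taken from attention."""
    alpha = attn_map / attn_map.max().clamp_min(1e-8)  # scale so the peak is opaque
    h, w = attn_map.shape
    color = torch.tensor(rgb, dtype=attn_map.dtype).expand(h, w, 3)
    return torch.cat([color, alpha.unsqueeze(-1)], dim=-1)

# Stand-ins for palette entries (RGB tuples in [0, 1]) and for two attention maps
colors = [(0.12, 0.47, 0.71), (1.0, 0.5, 0.05)]
maps = torch.rand(2, 48, 64)

overlays = [color_overlay(m, c) for m, c in zip(maps, colors)]
print(overlays[0].shape)
```

Each overlay can be drawn on top of the image with a plotting call that accepts RGBA arrays, so that high-attention regions take on the object's color.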
from detr.