
dpt's Introduction

DPT


This repo is the official implementation of DPT: Deformable Patch-based Transformer for Visual Recognition (ACM MM 2021). We provide code and models for the following tasks:

Image Classification: see classification/README.md for detailed instructions and information.

Object Detection: see detection/README.md for detailed instructions and information.

The paper has been released on [arXiv].

Introduction

Deformable Patch (DePatch) is a plug-and-play module. It learns to adaptively split input images into patches with different positions and scales in a data-driven way, rather than using predefined fixed patches. In this way, our method preserves the semantics within patches well.

In this repository, code and models for the Deformable Patch-based Transformer (DPT) are provided. As this field is developing rapidly, we would be glad to see our DePatch applied to other recent architectures, and we hope it promotes further research.
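To make the idea concrete, below is a minimal, hypothetical PyTorch sketch of a deformable patch embedding. It is not the repository's implementation (the real DePatch also predicts a per-patch scale and samples through the Deformable-DETR CUDA operator); the class name and the single-point sampling are simplifications for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DeformablePatchSketch(nn.Module):
    # Illustrative only: each patch predicts an offset for its own center and the
    # feature map is re-sampled there, instead of keeping fixed patch positions.
    def __init__(self, in_chans=3, embed_dim=64, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, stride=patch_size)
        self.offset = nn.Conv2d(embed_dim, 2, kernel_size=1)  # (dx, dy) per patch, data-driven

    def forward(self, x):
        feat = self.proj(x)                                     # (B, C, H/ps, W/ps)
        B, C, H, W = feat.shape
        offset = self.offset(feat).tanh()                       # bounded shifts, (B, 2, H, W)
        ys, xs = torch.meshgrid(torch.linspace(-1, 1, H, device=x.device),
                                torch.linspace(-1, 1, W, device=x.device),
                                indexing="ij")
        base = torch.stack((xs, ys)).unsqueeze(0)               # fixed patch centers, (1, 2, H, W)
        grid = (base + offset * (2.0 / H)).permute(0, 2, 3, 1)  # deformed centers, (B, H, W, 2)
        out = F.grid_sample(feat, grid, align_corners=True)     # sample at deformed positions
        return out.flatten(2).transpose(1, 2)                   # (B, N, C) token sequence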

Main Results

Image Classification

Training commands and pretrained models are provided >>> here <<<.

Method      #Params (M)  FLOPs (G)  Acc@1
DPT-Tiny    15.2         2.1        77.4
DPT-Small   26.4         4.0        81.0
DPT-Medium  46.1         6.9        81.9
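For a quick usage sketch, evaluation in the issues below is run through torch.distributed.launch; a typical command (the model name, dataset path, and checkpoint file here are placeholders) would look like:

python -m torch.distributed.launch --nproc_per_node=1 main.py --eval --model dpt_tiny --data-path /path/to/ImageNet --resume dpt_tiny.pth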

Object Detection

Training commands and detailed results are provided >>> here <<<.

Citation

@inproceedings{chenDPT21,
  title     = {DPT: Deformable Patch-based Transformer for Visual Recognition},
  author    = {Zhiyang Chen and Yousong Zhu and Chaoyang Zhao and Guosheng Hu and Wei Zeng and Jinqiao Wang and Ming Tang},
  booktitle = {Proceedings of the ACM International Conference on Multimedia},
  year      = {2021}
}

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgement

Our implementation is mainly based on PVT. The CUDA operator is borrowed from Deformable-DETR. You may refer to these repositories for further information.

dpt's People

Contributors

ivaopenlab, volgachen


dpt's Issues

Visualization of boxes

Hi @volgachen,

I followed this code #10 (comment) to visualize the boxes.
Since amplifier=2, I resized the image to (224×2)×(224×2) and plotted the boxes, centers, and anchors on the image.
In this figure, the red points are anchors, the blue points are the centers of the boxes, and the red rectangles are the boxes.

But I found that some boxes lie outside the image.
I wonder whether this is caused by the '16' in the following two lines of the code:

  1. boxes = (box_coder.boxes[0] * 224 + 16 ) * amplifier
  2. anchors = [(patch_embed.box_coder.anchor * 224 + 16)*amplifier for patch_embed in patch_embeds]

Or do I need to resize the image to (256×2)×(256×2) during visualization, since preprocessing resizes images to 256 and then center-crops them to 224?
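(For reference, a minimal sketch of the plotting step described above, assuming matplotlib, and assuming boxes holds one (x1, y1, x2, y2) row per patch and anchors one (x, y) row per patch; the exact tensor layouts are guesses:)

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

def plot_depatch_boxes(resized_image, boxes, anchors):
    # boxes: (N, 4) tensor of (x1, y1, x2, y2); anchors: (N, 2) tensor of (x, y)
    fig, ax = plt.subplots()
    ax.imshow(resized_image)                                   # the (224*2) x (224*2) image
    for x1, y1, x2, y2 in boxes.tolist():
        # red rectangles: the deformable patches
        ax.add_patch(mpatches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                        fill=False, edgecolor="red"))
        ax.plot((x1 + x2) / 2, (y1 + y2) / 2, ".", color="blue")  # blue: box centers
    ax.plot(anchors[:, 0].tolist(), anchors[:, 1].tolist(),
            ".", color="red")                                  # red: anchors
    plt.show()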

[figure: anchors (red points), box centers (blue points), and boxes (red rectangles) drawn over the resized image; some boxes extend beyond the image border]

Thank you very much.
Looking forward to further discussion with you.

getting error in MultiScaleDeformableAttention

Hi,
I am getting the error below.

ImportError: /home/shubham/anaconda3/envs/dpt1/lib/python3.8/site-packages/MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg/MultiScaleDeformableAttention.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZNK2at6Tensor7optionsEv

torch.version = 1.10.0.dev20210812+cu111
GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
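(For what it's worth: an undefined-symbol ImportError from a compiled extension usually means the operator was built against a different PyTorch than the one currently installed, which is easy to hit with nightly builds like the one above. Removing the old MultiScaleDeformableAttention egg and re-running sh ./make.sh, as in the solved issue further below, is the usual remedy. A minimal version check, assuming nothing beyond a working torch install:)

import torch
# The extension must be (re)built against these exact versions to load cleanly:
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())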

How to use DPT in ViT?

I want to apply it to super-resolution. Could you please help me solve the problem or give me some tips? Thank you very much.

The given model paths for small and medium fail to load state_dict

I have this issue when I test the model zoo for small and medium with the given paths. Can you help me check?

Traceback (most recent call last):
File "main.py", line 459, in
main(args)
File "main.py", line 376, in main
model_without_ddp.load_state_dict(checkpoint)
File "/home/-/anaconda3/envs/dpt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DeformablePatchTransformer:
Unexpected key(s) in state_dict: "block1.2.norm1.weight", "block1.2.norm1.bias", "block1.2.attn.q.weight", ... [abridged: the unexpected keys cover block1.2, block2.2 through block2.3, block3.2 through block3.17, and block4.2, i.e. every block past index 1 in each stage].
Traceback (most recent call last):
File "/home/-/anaconda3/envs/dpt/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/-/anaconda3/envs/dpt/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/-/anaconda3/envs/dpt/lib/python3.8/site-packages/torch/distributed/launch.py", line 260, in
main()
File "/home/-/anaconda3/envs/dpt/lib/python3.8/site-packages/torch/distributed/launch.py", line 255, in main
raise subprocess.CalledProcessError(returncode=process.returncode,
subprocess.CalledProcessError: Command '['/home/cvlab/anaconda3/envs/dpt/bin/python', '-u', 'main.py', '--eval', '--model', 'dpt_tiny', '--data-path', '/home/cvlab/datasets/ImageNet', '--resume', 'dpt_medium.pth']' returned non-zero exit status 1.
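(Worth noting: the command at the bottom of the traceback passes --model dpt_tiny while resuming dpt_medium.pth, and the unexpected keys run up to block3.17, i.e. an 18-block third stage like the depths=[3, 4, 18, 3] configuration quoted in the last issue below; that points to a model/checkpoint mismatch rather than a broken file. A quick way to see which architecture a checkpoint matches, assuming the file loads as a plain state dict:)

import torch

ckpt = torch.load("dpt_medium.pth", map_location="cpu")
state = ckpt.get("model", ckpt)  # some checkpoints nest the weights under "model"
for stage in ("block1", "block2", "block3", "block4"):
    indices = {k.split(".")[1] for k in state if k.startswith(stage + ".")}
    print(stage, "has", len(indices), "blocks")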

RuntimeError:

[screenshot] Sorry to bother you, but I have met a problem that I could not solve. [screenshot of the error] How can I solve this problem? Should I just install CUDA 11.3? Looking forward to your reply.

Extracting the patch by calling the MSDeformAttnFunction function

Thanks for sharing your work!
I find that depatch_embed.py (L112) calls this line:

output = MSDeformAttnFunction.apply(x, value_spatial_shapes, self.value_level_start_index, sampling_locations, attention_weights, 1)
I assume that this call may perform an extra deformable-attention computation. If that is true, it adds extra computation; is the comparison to PVT still fair?
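(For readers unfamiliar with the operator, here is the quoted call annotated with the argument names from Deformable-DETR's MSDeformAttnFunction; the comments are my reading of that interface, not the authors' documentation:)

output = MSDeformAttnFunction.apply(
    x,                             # value: flattened features to sample from
    value_spatial_shapes,          # (H, W) of each feature level
    self.value_level_start_index,  # start offset of each level in the flattened values
    sampling_locations,            # where each query samples, i.e. the deformed patch points
    attention_weights,             # weights used to mix the sampled points
    1,                             # im2col_step for the CUDA kernel
)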

How to use DPT in DETR

Has the author added the DPT module to the DETR part? I am a little anxious; thank you very much.

error: import MultiScaleDeformableAttention as MSDA (solved)

1. First, the installation needs to succeed:
   sh ./make.sh
   ......
   Processing dependencies for MultiScaleDeformableAttention==1.0
   Finished processing dependencies for MultiScaleDeformableAttention==1.0

2. Then I met the following error:
   import MultiScaleDeformableAttention as MSDA
   ImportError: libcudart.so.10.0: cannot open shared object file: No such file or directory

3. The fix:
   sudo ldconfig /usr/local/cuda-10.0/lib64
   (for CUDA 10.0)

4. After that it works successfully.

About the patch size

Hi @volgachen @ivaopenlab
Thank you for sharing your code.
I found that your patch size is set to 4 in dpt.py line 369.
Is this correct? Or is there another meaning behind it?

model = DeformablePatchTransformer(
        patch_size=4, embed_dims=embed_dims, num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
        patch_embeds=patch_embeds,
        **kwargs)

Looking forward to further discussion with you.
Thank you very much.
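(On the patch-size question above: patch_size=4 refers to the first-stage patch embedding, i.e. 4×4 patches with stride 4, as in PVT, on which this implementation is based per the Acknowledgement; later stages downsample further. A quick sanity check of the token-map sizes, assuming a 224×224 input and PVT-style stride-2 embeddings between stages:)

size = 224 // 4               # stage 1: 4x4 patches -> 56x56 tokens
for stage in range(1, 5):
    print(f"stage {stage}: {size}x{size} tokens")
    size //= 2                # assumed stride-2 patch embedding between stages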
