Efficient Vision Transformers and CNNs with Dynamic Spatial Sparsification

This repository contains PyTorch implementation for DynamicViT (NeurIPS 2021).

DynamicViT is a dynamic token sparsification framework that prunes redundant tokens in vision transformers progressively and dynamically, based on the input. Our method can reduce FLOPs by over 30% and improve throughput by over 40%, while keeping the accuracy drop within 0.5% for various vision transformers.
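To make "pruning redundant tokens" concrete, here is a minimal sketch of a single keep-or-drop step. It assumes that a lightweight prediction module has already given every token a keep score and that tokens are then selected by ranking those scores. It uses plain Python lists instead of tensors, and `keep_topk_tokens` is a hypothetical name, not a function from this repository.

```python
def keep_topk_tokens(scores, keep_ratio):
    """Return the indices (in original order) of the highest-scoring tokens.

    scores: per-token keep scores, assumed to come from a prediction module.
    keep_ratio: fraction of tokens to retain, e.g. 0.7.
    """
    num_keep = int(len(scores) * keep_ratio)  # truncate to a whole token count
    # Rank token indices by score, highest first.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Keep the top num_keep tokens, then restore their spatial order.
    return sorted(ranked[:num_keep])


scores = [0.9, 0.1, 0.8, 0.3, 0.7, 0.2, 0.6, 0.4, 0.5, 0.05]
print(keep_topk_tokens(scores, 0.5))  # [0, 2, 4, 6, 8]
```

Because the dropped tokens never reach later blocks, every subsequent attention and MLP layer runs on fewer tokens. This is where the FLOPs savings come from.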

[Project Page] [arXiv (NeurIPS 2021)]

🔥 Updates

We extend our method to more network architectures (i.e., ConvNeXt and Swin Transformers) and more tasks (i.e., object detection and semantic segmentation) with an improved dynamic spatial sparsification framework. Please refer to the extended version of our paper for details. The extended version has been accepted by T-PAMI.

[arXiv (T-PAMI, Journal Version)]

Image Examples


Video Examples

Model Zoo

We provide our DynamicViT models pretrained on ImageNet:

| name | model | rho (keep ratio) | acc@1 (%) | acc@5 (%) | FLOPs | url |
|---|---|---|---|---|---|---|
| DynamicViT-DeiT-256/0.7 | deit-256 | 0.7 | 76.53 | 93.12 | 1.3G | Google Drive / Tsinghua Cloud |
| DynamicViT-DeiT-S/0.7 | deit-s | 0.7 | 79.32 | 94.68 | 2.9G | Google Drive / Tsinghua Cloud |
| DynamicViT-DeiT-B/0.7 | deit-b | 0.7 | 81.43 | 95.46 | 11.4G | Google Drive / Tsinghua Cloud |
| DynamicViT-LVViT-S/0.5 | lvvit-s | 0.5 | 81.97 | 95.76 | 3.7G | Google Drive / Tsinghua Cloud |
| DynamicViT-LVViT-S/0.7 | lvvit-s | 0.7 | 83.08 | 96.25 | 4.6G | Google Drive / Tsinghua Cloud |
| DynamicViT-LVViT-M/0.7 | lvvit-m | 0.7 | 83.82 | 96.58 | 8.5G | Google Drive / Tsinghua Cloud |

🔥 Updates: We also provide our DynamicCNN and DynamicSwin models pretrained on ImageNet:

| name | model | rho (keep ratio) | acc@1 (%) | acc@5 (%) | FLOPs | url |
|---|---|---|---|---|---|---|
| DynamicCNN-T/0.7 | convnext-t | 0.7 | 81.59 | 95.72 | 3.6G | Google Drive / Tsinghua Cloud |
| DynamicCNN-T/0.9 | convnext-t | 0.9 | 82.06 | 95.89 | 3.9G | Google Drive / Tsinghua Cloud |
| DynamicCNN-S/0.7 | convnext-s | 0.7 | 82.57 | 96.29 | 5.8G | Google Drive / Tsinghua Cloud |
| DynamicCNN-S/0.9 | convnext-s | 0.9 | 83.12 | 96.42 | 6.8G | Google Drive / Tsinghua Cloud |
| DynamicCNN-B/0.7 | convnext-b | 0.7 | 83.45 | 96.56 | 10.2G | Google Drive / Tsinghua Cloud |
| DynamicCNN-B/0.9 | convnext-b | 0.9 | 83.96 | 96.76 | 11.9G | Google Drive / Tsinghua Cloud |
| DynamicSwin-T/0.7 | swin-t | 0.7 | 80.91 | 95.42 | 4.0G | Google Drive / Tsinghua Cloud |
| DynamicSwin-S/0.7 | swin-s | 0.7 | 83.21 | 96.33 | 6.9G | Google Drive / Tsinghua Cloud |
| DynamicSwin-B/0.7 | swin-b | 0.7 | 83.43 | 96.45 | 12.1G | Google Drive / Tsinghua Cloud |
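If you need to pick a checkpoint under a compute budget, the two model zoo tables can be queried directly. The sketch below copies name, acc@1 and FLOPs from those tables. `best_under_budget` is a hypothetical helper. It treats FLOPs as the only cost measure, although real throughput also depends on the backbone and the hardware.

```python
# (name, acc@1 in %, GFLOPs) copied from the model zoo tables.
MODEL_ZOO = [
    ("DynamicViT-DeiT-256/0.7", 76.53, 1.3),
    ("DynamicViT-DeiT-S/0.7", 79.32, 2.9),
    ("DynamicViT-DeiT-B/0.7", 81.43, 11.4),
    ("DynamicViT-LVViT-S/0.5", 81.97, 3.7),
    ("DynamicViT-LVViT-S/0.7", 83.08, 4.6),
    ("DynamicViT-LVViT-M/0.7", 83.82, 8.5),
    ("DynamicCNN-T/0.7", 81.59, 3.6),
    ("DynamicCNN-T/0.9", 82.06, 3.9),
    ("DynamicCNN-S/0.7", 82.57, 5.8),
    ("DynamicCNN-S/0.9", 83.12, 6.8),
    ("DynamicCNN-B/0.7", 83.45, 10.2),
    ("DynamicCNN-B/0.9", 83.96, 11.9),
    ("DynamicSwin-T/0.7", 80.91, 4.0),
    ("DynamicSwin-S/0.7", 83.21, 6.9),
    ("DynamicSwin-B/0.7", 83.43, 12.1),
]


def best_under_budget(max_gflops):
    """Return the most accurate (name, acc@1, GFLOPs) entry within the budget, or None."""
    candidates = [entry for entry in MODEL_ZOO if entry[2] <= max_gflops]
    return max(candidates, key=lambda entry: entry[1]) if candidates else None


print(best_under_budget(5.0))  # ('DynamicViT-LVViT-S/0.7', 83.08, 4.6)
```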

Usage

Requirements

  • torch>=1.8.0
  • torchvision>=0.9.0
  • timm==0.3.2
  • tensorboardX
  • six
  • fvcore
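The pinned versions above (especially `timm==0.3.2`) are easy to get wrong in an existing environment. This is a minimal standard-library sketch for checking them, assuming Python 3.8+ for `importlib.metadata`. Its version parsing is deliberately simplified: only numeric components are compared and local suffixes such as `+cu111` are ignored. The `packaging` library's `Version` class is the more robust option if it is installed.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# The three version-pinned requirements from the list above.
REQUIREMENTS = [("torch", ">=", "1.8.0"), ("torchvision", ">=", "0.9.0"), ("timm", "==", "0.3.2")]


def parse(v):
    """'1.8.0+cu111' -> (1, 8, 0); stops at the first non-numeric component."""
    nums = []
    for part in v.split("+")[0].split("."):
        m = re.match(r"\d+", part)
        if not m:
            break
        nums.append(int(m.group()))
    return tuple(nums)


def satisfies(installed, op, required):
    a, b = parse(installed), parse(required)
    n = max(len(a), len(b))  # pad so that (1, 8) compares equal to (1, 8, 0)
    a, b = a + (0,) * (n - len(a)), b + (0,) * (n - len(b))
    if op == "==":
        return a == b
    if op == ">=":
        return a >= b
    raise ValueError(f"unsupported operator: {op}")


if __name__ == "__main__":
    for name, op, req in REQUIREMENTS:
        try:
            have = version(name)
        except PackageNotFoundError:
            print(f"{name}: not installed (need {op}{req})")
            continue
        status = "ok" if satisfies(have, op, req) else "MISMATCH"
        print(f"{name} {have}: {status} (need {op}{req})")
```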

Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be:

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
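A quick check of this layout before launching a multi-GPU job can catch extraction mistakes early. The sketch below is a hypothetical helper, not part of the repository. It assumes the standard ILSVRC2012 split, with 1000 classes and one sub-folder per class in both `train/` and `val/`, and it only inspects folder names, not image files.

```python
from pathlib import Path


def check_imagenet_layout(root, expected_classes=1000):
    """Return a list of problems found under root; an empty list means the layout looks right."""
    root = Path(root)
    problems = []
    class_sets = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        # One sub-folder per class, e.g. train/n01440764
        classes = {p.name for p in split_dir.iterdir() if p.is_dir()}
        class_sets[split] = classes
        if len(classes) != expected_classes:
            problems.append(f"{split}: found {len(classes)} class folders, expected {expected_classes}")
    if len(class_sets) == 2 and class_sets["train"] != class_sets["val"]:
        problems.append("train/ and val/ contain different class folders")
    return problems
```

For example, `check_imagenet_layout("/path/to/ILSVRC2012/")` should return `[]` for a correctly extracted dataset. In particular, `val/` must be split into class folders as shown above and not left as a flat list of images.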

Model preparation: download pre-trained models if necessary:

| model | url | model | url |
|---|---|---|---|
| DeiT-Small | link | LVViT-S | link |
| DeiT-Base | link | LVViT-M | link |
| ConvNeXt-T | link | Swin-T | link |
| ConvNeXt-S | link | Swin-S | link |
| ConvNeXt-B | link | Swin-B | link |

Demo

You can try DynamicViT on Colab. Thanks to @dirtycomputer for the contribution.

We also provide a Jupyter notebook where you can run the visualization of DynamicViT.

To run the demo, you need to install matplotlib.

Evaluation

To evaluate a pre-trained DynamicViT model on the ImageNet validation set with a single GPU, run:

python infer.py --data_path /path/to/ILSVRC2012/ --model model_name \
--model_path /path/to/model --base_rate 0.7 
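To evaluate several checkpoints, the command can be generated from the model zoo table. The sketch below assumes that `--base_rate` should match the rho each checkpoint was trained with (hence 0.5 for `DynamicViT-LVViT-S/0.5`). It only covers the DynamicViT entries, since the evaluation instructions refer to DynamicViT models. The paths are the same placeholders used in the command; replace them with your own dataset and checkpoint locations.

```python
import shlex

# (--model value, rho) for each DynamicViT checkpoint, copied from the model zoo table.
DYNAMICVIT_ZOO = {
    "DynamicViT-DeiT-256/0.7": ("deit-256", 0.7),
    "DynamicViT-DeiT-S/0.7": ("deit-s", 0.7),
    "DynamicViT-DeiT-B/0.7": ("deit-b", 0.7),
    "DynamicViT-LVViT-S/0.5": ("lvvit-s", 0.5),
    "DynamicViT-LVViT-S/0.7": ("lvvit-s", 0.7),
    "DynamicViT-LVViT-M/0.7": ("lvvit-m", 0.7),
}


def infer_command(name, data_path, model_path):
    """Build the infer.py argument list for one model zoo entry."""
    model, rho = DYNAMICVIT_ZOO[name]
    return ["python", "infer.py", "--data_path", data_path, "--model", model,
            "--model_path", model_path, "--base_rate", str(rho)]


if __name__ == "__main__":
    for name in DYNAMICVIT_ZOO:
        # Placeholder paths; point model_path at the checkpoint you downloaded.
        print(shlex.join(infer_command(name, "/path/to/ILSVRC2012/", "/path/to/model")))
```

Passing the list to `subprocess.run` instead of printing it would launch each evaluation in turn.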

Training

To train Dynamic Spatial Sparsification models on ImageNet, run:

(You can train models with a different keep ratio by adjusting --base_rate.)
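The keep ratio compounds across pruning stages. The sketch below assumes the hierarchical schedule described in the NeurIPS paper for ViT backbones: three pruning stages whose cumulative keep ratios are rho, rho², and rho³. It also assumes the kept count is truncated to an integer. Treat the exact counts as illustrative, since the repository's rounding may differ.

```python
def stage_keep_ratios(base_rate, num_stages=3):
    """Cumulative keep ratio after each pruning stage: rho, rho**2, rho**3, ..."""
    return [base_rate ** s for s in range(1, num_stages + 1)]


def stage_token_counts(num_tokens, base_rate, num_stages=3):
    """Patch tokens still kept after each stage, truncated to an int."""
    return [int(num_tokens * r) for r in stage_keep_ratios(base_rate, num_stages)]


# A 224x224 input with 16x16 patches gives 14 * 14 = 196 patch tokens.
print(stage_token_counts(196, 0.7))  # [137, 96, 67]
print(stage_token_counts(196, 0.5))  # [98, 49, 24]
```

This is why a lower `--base_rate` (for example 0.5 for LVViT-S) cuts FLOPs much more than the single-stage ratio suggests.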

DeiT-S

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamicvit_deit-s --model deit-s --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 30 --base_rate 0.7 --lr 1e-3 --warmup_epochs 5

DeiT-B

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamicvit_deit-b --model deit-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 30 --base_rate 0.7 --lr 1e-3 --warmup_epochs 5 --drop_path 0.2 --ratio_weight 5.0

LV-ViT-S

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamicvit_lvvit-s --model lvvit-s --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 30 --base_rate 0.7 --lr 1e-3 --warmup_epochs 5

LV-ViT-M

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamicvit_lvvit-m --model lvvit-m --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 30 --base_rate 0.7 --lr 1e-3 --warmup_epochs 5

DynamicViT can also achieve comparable performance with only 15 epochs of training (around 0.1% lower accuracy than with 30 epochs).

ConvNeXt-T

Train on 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamic_conv-t --model convnext-t --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 4 --lr_scale 0.2

Train on 4 8-GPU nodes:

python run_with_submitit.py --nodes 4 --ngpus 8 --output_dir logs/dynamic_conv-t --model convnext-t --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 1 --lr_scale 0.2
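The single-node and 4-node recipes differ only in `--update_freq`. Assuming `--batch_size` is per GPU and `--update_freq` is the number of gradient-accumulation steps per optimizer update (its meaning in the ConvNeXt codebase this code builds on), both recipes take optimizer steps over the same number of images:

```python
def effective_batch_size(per_gpu_batch, gpus_per_node, nodes=1, update_freq=1):
    """Images per optimizer step, assuming a per-GPU batch size and
    update_freq gradient-accumulation steps."""
    return per_gpu_batch * gpus_per_node * nodes * update_freq


single_node = effective_batch_size(128, 8, nodes=1, update_freq=4)  # 128 * 8 * 4
multi_node = effective_batch_size(128, 8, nodes=4, update_freq=1)   # 128 * 8 * 4 * 1
print(single_node, multi_node)  # 4096 4096
```

If you run on a different number of GPUs, scaling `--update_freq` inversely keeps this product, and therefore the learning-rate setting, consistent with the reference recipe.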

ConvNeXt-S

Train on 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamic_conv-s --model convnext-s --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 4 --lr_scale 0.2

Train on 4 8-GPU nodes:

python run_with_submitit.py --nodes 4 --ngpus 8 --output_dir logs/dynamic_conv-s --model convnext-s --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 1 --lr_scale 0.2

ConvNeXt-B

Train on 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamic_conv-b --model convnext-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.5 --update_freq 4 --lr_scale 0.2

Train on 4 8-GPU nodes:

python run_with_submitit.py --nodes 4 --ngpus 8 --output_dir logs/dynamic_conv-b --model convnext-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.5 --update_freq 1 --lr_scale 0.2

Swin-T

Train on 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamic_swin-t --model swin-t --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 4 --lr_scale 0.2

Train on 4 8-GPU nodes:

python run_with_submitit.py --nodes 4 --ngpus 8 --output_dir logs/dynamic_swin-t --model swin-t --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 1 --lr_scale 0.2

Swin-S

Train on 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamic_swin-s --model swin-s --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 4 --lr_scale 0.2

Train on 4 8-GPU nodes:

python run_with_submitit.py --nodes 4 --ngpus 8 --output_dir logs/dynamic_swin-s --model swin-s --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.2 --update_freq 1 --lr_scale 0.2

Swin-B

Train on 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamic_swin-b --model swin-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.5 --update_freq 4 --lr_scale 0.2

Train on 4 8-GPU nodes:

python run_with_submitit.py --nodes 4 --ngpus 8 --output_dir logs/dynamic_swin-b --model swin-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 120 --base_rate 0.7 --lr 4e-3 --drop_path 0.5 --update_freq 1 --lr_scale 0.2
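The ConvNeXt and Swin commands above share every flag except `--model`, `--output_dir`, and `--drop_path` (0.5 for the B variants, 0.2 otherwise). The launcher and `--update_freq` depend only on 1 versus 4 nodes. The sketch below rebuilds those commands from that pattern so you can change `--base_rate` or the data path in one place. `train_command` is a hypothetical helper, and it only reproduces the two node configurations shown here.

```python
import shlex

# Per-model --drop_path values, read off the ConvNeXt / Swin commands above.
DROP_PATH = {
    "convnext-t": 0.2, "convnext-s": 0.2, "convnext-b": 0.5,
    "swin-t": 0.2, "swin-s": 0.2, "swin-b": 0.5,
}


def output_dir(model):
    # convnext-t -> logs/dynamic_conv-t, swin-t -> logs/dynamic_swin-t
    return "logs/dynamic_" + model.replace("convnext", "conv")


def train_command(model, base_rate=0.7, data_path="/path/to/ILSVRC2012/", nodes=1):
    """Rebuild the training command for one 8-GPU node or four 8-GPU nodes."""
    if nodes not in (1, 4):
        raise ValueError("only the 1-node and 4-node recipes are reproduced here")
    update_freq = 4 if nodes == 1 else 1  # keeps 128 * 8 * nodes * update_freq at 4096
    args = [
        "--output_dir", output_dir(model), "--model", model,
        "--input_size", "224", "--batch_size", "128", "--data_path", data_path,
        "--epochs", "120", "--base_rate", str(base_rate), "--lr", "4e-3",
        "--drop_path", str(DROP_PATH[model]), "--update_freq", str(update_freq),
        "--lr_scale", "0.2",
    ]
    if nodes == 1:
        launcher = ["python", "-m", "torch.distributed.launch",
                    "--nproc_per_node=8", "--use_env", "main.py"]
    else:
        launcher = ["python", "run_with_submitit.py", "--nodes", "4", "--ngpus", "8"]
    return launcher + args


print(shlex.join(train_command("convnext-b")))
print(shlex.join(train_command("swin-t", nodes=4)))
```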

License

MIT License

Acknowledgements

Our code is based on pytorch-image-models, DeiT, LV-ViT, ConvNeXt and Swin-Transformer.

Citation

If you find our work useful in your research, please consider citing:

@inproceedings{rao2021dynamicvit,
  title={DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification},
  author={Rao, Yongming and Zhao, Wenliang and Liu, Benlin and Lu, Jiwen and Zhou, Jie and Hsieh, Cho-Jui},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year = {2021}
}
@article{rao2022dynamicvit,
  title={Dynamic Spatial Sparsification for Efficient Vision Transformers and Convolutional Neural Networks},
  author={Rao, Yongming and Liu, Zuyan and Zhao, Wenliang and Zhou, Jie and Lu, Jiwen},
  journal={arXiv preprint arXiv:2207.01580},
  year={2022}
}

Contributors

raoyongming, wl-zhao, johnwahlig, liuzuyan