Unified Visual Transformer Compression

License: MIT

Code for the paper: [ICLR 2022] Unified Visual Transformer Compression.

Shixing Yu*, Tianlong Chen*, Jiayi Shen, Huan Yuan, Jianchao Tan, Sen Yang, Ji Liu, Zhangyang Wang

Overall Results

Extensive experiments are conducted with several DeiT backbones on ImageNet, which consistently verify the effectiveness of our proposal. For example, UVC on DeiT-Tiny (with/without distillation tokens) yields around 50% FLOPs reduction, with little performance degradation (only 0.3%/0.9% loss compared to the baseline).

Method            Acc (%)         FLOPs (G)   Compression Ratio (%)
DeiT-Small        79.8            4.6         100
SCOP              77.5 (-2.3)     2.6         56.4
PoWER             78.3 (-1.5)     2.7         58.7
HVT               78.0 (-1.8)     2.4         52.2
Patch Slimming    79.4 (-0.4)     2.6         56.5
UVC (Ours)        79.44 (-0.36)   2.65        57.61
UVC (Ours)        78.82 (-0.98)   2.32        50.41
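
The Compression Ratio column appears to be the FLOPs that remain, as a percentage of the DeiT-Small baseline (4.6 G). A minimal sketch of that arithmetic follows; the function name `remaining_flops_pct` is hypothetical and not part of this repository.

```python
def remaining_flops_pct(flops_g: float, baseline_flops_g: float) -> float:
    """FLOPs kept after compression, as a percentage of the baseline FLOPs."""
    return 100.0 * flops_g / baseline_flops_g


BASELINE_FLOPS_G = 4.6  # DeiT-Small row of the table

# FLOPs (G) values copied from the table above
for method, flops_g in [("HVT", 2.4), ("PoWER", 2.7), ("UVC (Ours)", 2.65)]:
    pct = remaining_flops_pct(flops_g, BASELINE_FLOPS_G)
    print(f"{method}: {pct:.2f}% of baseline FLOPs")
```

Values recomputed this way can differ from the table in the last digit (for example, 2.32 G gives about 50.43 rather than 50.41), presumably because the table's FLOPs are rounded.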

Overview of Proposed UVC

We formulate and solve UVC as a unified constrained optimization problem. It simultaneously learns model weights, layer-wise pruning ratios/masks, and skip configurations, under a distillation loss and an overall budget constraint.
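
One generic way to write a problem of this shape is sketched here. The symbols are assumptions made for illustration and are not the paper's notation: W stands for the model weights, s for the layer-wise pruning ratios/masks, g for the skip configuration, and B for the overall budget.

```latex
\min_{W,\, s,\, g} \; \mathcal{L}_{\mathrm{distill}}(W, s, g)
\quad \text{subject to} \quad
\mathrm{FLOPs}(s, g) \le B
```

All three groups of variables are optimized jointly, and the single budget constraint ties the pruning and skipping decisions together.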

[Figure: architecture of the proposed UVC framework]

Implementations of UVC

Set the Environment

conda create -n vit python=3.6
conda activate vit

pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

pip install tqdm scipy timm
pip install ml_collections
pip install tensorboard

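git clone https://github.com/NVIDIA/apex

cd apex

# Option 1: build apex with its C++ and CUDA extensions
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

# Option 2: Python-only build (use this instead if the extension build fails)
pip install -v --disable-pip-version-check --no-cache-dir ./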
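
Before launching training, it can help to confirm that the packages installed above are importable in the active environment. The following is a minimal sketch using only the standard library; the helper name `missing_packages` and the `REQUIRED` list are assumptions based on the install commands, not something shipped with this repository.

```python
import importlib.util
from typing import Iterable, List


def missing_packages(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names assumed to correspond to the pip/apex installs above
REQUIRED = [
    "torch", "torchvision", "torchaudio", "tqdm", "scipy",
    "timm", "ml_collections", "tensorboard", "apex",
]

if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("missing:", ", ".join(missing) if missing else "none")
```

Run it from a directory other than the cloned `apex` folder; otherwise the local `apex/` source directory can be picked up and mask a failed install.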

Running Commands

Training has two stages.

  • Stage 1 is UVC Training. It optimizes the architecture with a primal-dual algorithm to find the optimal block-wise layout and skip configuration.
  • Stage 2 is Post Training. The architecture is fixed and only the weights are updated, which helps the network regain accuracy.
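
To illustrate what a primal-dual update looks like in general, here is a toy sketch on a one-variable problem: minimize (x - 2)^2 subject to x <= 1. It is not the UVC training loop; the function name, step sizes, and the problem itself are assumptions chosen for illustration. The primal variable takes a gradient descent step on the Lagrangian, and the dual variable takes a projected gradient ascent step on the constraint violation.

```python
def primal_dual_toy(steps: int = 2000, lr: float = 0.05, dual_lr: float = 0.05):
    """Gradient descent-ascent on L(x, lam) = (x - 2)^2 + lam * (x - 1)."""
    x, lam = 0.0, 0.0
    for _ in range(steps):
        grad_x = 2.0 * (x - 2.0) + lam   # dL/dx
        violation = x - 1.0              # dL/dlam, positive when x > 1
        x -= lr * grad_x                 # primal step: descend in x
        # dual step: ascend in lam, then project back onto lam >= 0
        lam = max(0.0, lam + dual_lr * violation)
    return x, lam


if __name__ == "__main__":
    x, lam = primal_dual_toy()
    print(f"x = {x:.4f}, lambda = {lam:.4f}")
```

The iterates settle at the constrained optimum x = 1 with multiplier lambda = 2, which is the point where 2(x - 2) + lambda = 0 holds on the constraint boundary. In a budget-constrained setting, the multiplier plays the same role: it grows while the budget is exceeded and pushes the primal variables back toward feasibility.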

Stage 1: UVC Training

python -W ignore -m torch.distributed.launch \
--nproc_per_node=2 \
--master_port 6019 joint_train.py \
--gpu_num '0,1' \
--uvc_train \
--model_type deit_tiny_patch16_224 \
--model_path https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth \
--distillation-type soft \
--distillation-alpha 0.1 \
--train_batch_size 512 \
--num_epochs 30 \
--eval_every 1000 \
--flops_with_mhsa 1 \
--zlr_schedule_list "1,5,9,13,17" \
--learning_rate 1e-4 \
--enable_deit 0 \
--budget 0.5 \
--enable_pruning 1 \
--enable_block_gating 1 \
--enable_patch_gating 1 \
--gating_weight 5e-4 \
--patch_weight 5 \
--patch_l1_weight 0.01 \
--patchloss "l1" \
--use_gumbel 1 \
--glr 0.1 \
--patchlr 0.01 \
--num_workers 64 \
--seed 730 \
--output_dir mc_deit_tiny_patch16_224_with_patch \
--log_interval 1000 \
--eps 0.1 \
--eps_decay 0.92 \
--enable_warmup 1 \
--warmup_epochs 5 \
--warmup_lr 1e-4 \
--z_grad_clip 0.5 \
--gating_interval 50
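
The command above starts two processes (`--nproc_per_node=2`) and lists two GPUs (`--gpu_num '0,1'`). A small pre-launch check can catch a mismatch between the two. This sketch assumes the script expects one process per listed GPU, which is not confirmed by the repository; the helper names are hypothetical.

```python
from typing import List


def parse_gpu_ids(gpu_num: str) -> List[int]:
    """Split a comma-separated GPU list such as '0,1' into integer ids."""
    return [int(part) for part in gpu_num.split(",") if part.strip()]


def launch_is_consistent(nproc_per_node: int, gpu_num: str) -> bool:
    """Assumed rule: one process per listed GPU, and no GPU listed twice."""
    ids = parse_gpu_ids(gpu_num)
    return len(ids) == nproc_per_node and len(set(ids)) == len(ids)


if __name__ == "__main__":
    print(launch_is_consistent(2, "0,1"))  # values used in the Stage 1 command
    print(launch_is_consistent(2, "0"))    # one GPU listed for two processes
```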

Stage 2: Post Training

python -m torch.distributed.launch \
--nproc_per_node=2 --master_port 6382 post_train.py \
--pretrained 0 \
--model_type "deit_small_patch16_224" \
--model_path https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth \
--checkpoint_dir /home/shixing/deit_small_patch16_224_11.pth.tar \
--distillation-type soft \
--distillation-alpha 0.1 \
--train_batch_size 256 \
--gpu_num '2,3' \
--epochs 120 \
--eval_every 1000 \
--output_dir exp/deit_small_nasprune_0.58 \
--num_workers 64

Citation

@inproceedings{yu2022unified,
  author = {Yu, Shixing and Chen, Tianlong and Shen, Jiayi and Yuan, Huan and Tan, Jianchao and Yang, Sen and Liu, Ji and Wang, Zhangyang},
  title = {Unified Visual Transformer Compression},
  booktitle = {ICLR},
  year = {2022},
}

Results

[Figure: results for deit-tiny-distilled-patch16-224]

Acknowledgement

ViT: https://github.com/jeonsworld/ViT-pytorch

ViT: https://github.com/google-research/vision_transformer

DeiT: https://github.com/facebookresearch/deit

T2T-ViT: https://github.com/yitu-opensource/T2T-ViT

uvc's Issues

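Problems with the code

All dependencies are installed, so this is not a dependency problem, but the code still fails at runtime with all sorts of errors, some of which are plainly visible in the editor. Please upload code that runs correctly, along with the log files for the best results the authors report.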

Pre-trained model checkpoint

Awesome work!
Would it be possible to provide a pre-trained model checkpoint?
Thank you very much!

Unable to reproduce the code.

Hello, I have read the paper and this is a great work on vision transformer compression!

However, when I try to reproduce the results by following the commands on GitHub, I always get 0 accuracy after joint training. Could you provide the pre-trained model checkpoint so that I can easily evaluate it on ImageNet?

I simply followed the commands on GitHub to set up the environment and ran the Stage 1 UVC training command. I installed the apex library without the global options and used 2 NVIDIA GeForce RTX 2080 GPUs.

Best regards,

Question on the result of DeiT-base

Hi, thanks for sharing this nice work. Will the resulting model weights for DeiT-base be released?
Also, could you please provide the settings used for DeiT-B so that your results can be reproduced?

Best regards
