

Official PyTorch implementation of RepViT-SAM and RepViT. CVPR 2024.


Latency is measured by deploying the models on an iPhone 12 with Core ML Tools.


Models are trained on ImageNet-1K; latency is measured by deploying them on an iPhone 12 with Core ML Tools.

RepViT-SAM: Towards Real-Time Segmenting Anything.
Ao Wang, Hui Chen, Zijia Lin, Jungong Han, and Guiguang Ding
[arXiv] [Project Page]

Abstract: Segment Anything Model (SAM) has recently shown impressive zero-shot transfer performance on various computer vision tasks. However, its heavy computation costs remain daunting for practical applications. MobileSAM proposes to replace the heavyweight image encoder in SAM with TinyViT by employing distillation, which results in a significant reduction in computational requirements. However, its deployment on resource-constrained mobile devices still encounters challenges due to the substantial memory and computational overhead caused by self-attention mechanisms. Recently, RepViT achieves the state-of-the-art performance and latency trade-off on mobile devices by incorporating efficient architectural designs of ViTs into CNNs. Here, to achieve real-time segmenting anything on mobile devices, following MobileSAM, we replace the heavyweight image encoder in SAM with a RepViT model, ending up with the RepViT-SAM model. Extensive experiments show that RepViT-SAM enjoys significantly better zero-shot transfer capability than MobileSAM, along with nearly 10× faster inference speed.

RepViT: Revisiting Mobile CNN From ViT Perspective.
Ao Wang, Hui Chen, Zijia Lin, Jungong Han, and Guiguang Ding
[arXiv]

Abstract: Recently, lightweight Vision Transformers (ViTs) demonstrate superior performance and lower latency compared with lightweight Convolutional Neural Networks (CNNs) on resource-constrained mobile devices. This improvement is usually attributed to the multi-head self-attention module, which enables the model to learn global representations. However, the architectural disparities between lightweight ViTs and lightweight CNNs have not been adequately examined. In this study, we revisit the efficient design of lightweight CNNs and emphasize their potential for mobile devices. We incrementally enhance the mobile-friendliness of a standard lightweight CNN, specifically MobileNetV3, by integrating the efficient architectural choices of lightweight ViTs. This ends up with a new family of pure lightweight CNNs, namely RepViT. Extensive experiments show that RepViT outperforms existing state-of-the-art lightweight ViTs and exhibits favorable latency in various vision tasks. On ImageNet, RepViT achieves over 80% top-1 accuracy with 1 ms latency on an iPhone 12, which is the first time for a lightweight model, to the best of our knowledge. Our largest model, RepViT-M2.3, obtains 83.7% accuracy with only 2.3 ms latency.


UPDATES 🔥

  • 2023/12/17: Grounding-SAM supports RepViT-SAM with Grounded-RepViT-SAM. Thanks!
  • 2023/12/11: RepViT-SAM has been released. Please refer to RepViT-SAM.
  • 2023/12/11: RepViT-M0.6 has been released, achieving 74.1% top-1 accuracy with ~0.6 ms latency. Its checkpoint is here
  • 2023/09/28: RepViT-M0.9/1.0/1.1/1.5/2.3 models have been released.
  • 2023/07/27: RepViT models have been integrated into timm. See huggingface/pytorch-image-models#1876.

Classification on ImageNet-1K

Models

Model   Top-1 % (300e / 450e)   #Params   MACs   Latency   Ckpt          Core ML       Log
M0.9    78.7 / 79.1             5.1M      0.8G   0.9 ms    300e / 450e   300e / 450e   300e / 450e
M1.0    80.0 / 80.3             6.8M      1.1G   1.0 ms    300e / 450e   300e / 450e   300e / 450e
M1.1    80.7 / 81.2             8.2M      1.3G   1.1 ms    300e / 450e   300e / 450e   300e / 450e
M1.5    82.3 / 82.5             14.0M     2.3G   1.5 ms    300e / 450e   300e / 450e   300e / 450e
M2.3    83.3 / 83.7             22.9M     4.5G   2.3 ms    300e / 450e   300e / 450e   300e / 450e

300e / 450e: models trained for 300 / 450 epochs.

Tips: Convert a training-time RepViT into the inference-time structure

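from timm.models import create_model
import utils  # utils.py from this repository

model = create_model('repvit_m0_9')
model.eval()  # the converted model is intended for inference only
utils.replace_batchnorm(model)  # fuse BatchNorm and re-parameterized branches in place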
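In the snippet above, utils.replace_batchnorm performs the conversion. To show what such a conversion does numerically, here is a NumPy-only sketch (not the repository's code) of two standard re-parameterization identities: folding an inference-mode BatchNorm into the preceding convolution, and merging a 3×3 depthwise conv, a 1×1 depthwise conv and an identity shortcut into a single 3×3 depthwise conv. The three-branch RepVGG-style layout is an assumption for illustration, and the helper names fold_bn, depthwise_conv3x3 and merge_branches are hypothetical.

```python
import numpy as np


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold an inference-mode BatchNorm into the preceding conv (hypothetical helper).

    weight has the output channel as its first axis; bias has shape (C_out,).
    Returns (w, b) such that conv(x; w, b) == bn(conv(x; weight, bias)).
    """
    scale = gamma / np.sqrt(var + eps)
    w = weight * scale.reshape(-1, *([1] * (weight.ndim - 1)))
    b = (bias - mean) * scale + beta
    return w, b


def depthwise_conv3x3(x, kernel, bias):
    """Naive depthwise 3x3 cross-correlation, stride 1, zero padding 1.

    x: (C, H, W), kernel: (C, 3, 3), bias: (C,).
    """
    _, h, w = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros_like(x)
    for i in range(3):
        for j in range(3):
            out += kernel[:, i, j][:, None, None] * xp[:, i:i + h, j:j + w]
    return out + bias[:, None, None]


def merge_branches(k3, b3, k1, b1):
    """Merge 3x3 DW conv + 1x1 DW conv + identity into one 3x3 DW conv (assumed layout).

    k3: (C, 3, 3), k1: (C,) per-channel 1x1 weights. The 1x1 branch and the
    identity shortcut both act only on the centre tap of a 3x3 kernel.
    """
    k = k3.copy()
    k[:, 1, 1] += k1 + 1.0
    return k, b3 + b1


rng = np.random.default_rng(0)
C, H, W = 4, 5, 6
x = rng.standard_normal((C, H, W))
k3, b3 = rng.standard_normal((C, 3, 3)), rng.standard_normal(C)
k1, b1 = rng.standard_normal(C), rng.standard_normal(C)

# Training-time form: three parallel branches whose outputs are summed.
y_train = depthwise_conv3x3(x, k3, b3) + (k1[:, None, None] * x + b1[:, None, None]) + x
k, b = merge_branches(k3, b3, k1, b1)
print(np.allclose(depthwise_conv3x3(x, k, b), y_train))  # True
```

Both identities rely only on convolution being linear in its kernel, which is why the fused network produces the same outputs with fewer, larger-but-single kernels at inference time.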

Latency Measurement

The latency reported for RepViT on iPhone 12 (iOS 16) is measured with the benchmark tool in Xcode 14. For example, here is a latency measurement of RepViT-M0.9:

Tips: export the model to a Core ML model

python export_coreml.py --model repvit_m0_9 --ckpt pretrain/repvit_m0_9_distill_300e.pth

Tips: measure the throughput on GPU

python speed_gpu.py --model repvit_m0_9
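The internals of speed_gpu.py (batch size, iteration counts) are not shown here. As a rough illustration of how GPU throughput is usually measured, the sketch below (a hypothetical stdlib helper, not the repository's script) runs a few warm-up batches, then times a fixed number of batches and reports images per second. The sync argument is an assumption: with PyTorch on CUDA you would pass torch.cuda.synchronize, because kernel launches are asynchronous and the clock would otherwise stop before the work finishes.

```python
import time


def measure_throughput(run_batch, batch_size, warmup=10, iters=50, sync=None):
    """Return images per second for a callable that processes one batch (hypothetical helper).

    run_batch: zero-argument callable that runs one forward pass on one batch.
    sync: optional zero-argument callable that blocks until queued device work
          has finished (for CUDA, e.g. torch.cuda.synchronize).
    """
    if iters <= 0:
        raise ValueError("iters must be positive")
    for _ in range(warmup):  # warm-up: lazy initialization, kernel selection, caches
        run_batch()
    if sync is not None:
        sync()  # make sure warm-up work is done before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        run_batch()
    if sync is not None:
        sync()  # wait for all timed work before stopping the clock
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


# Stand-in CPU workload; with PyTorch, run_batch would wrap model(images) under torch.no_grad().
ips = measure_throughput(lambda: sum(i * i for i in range(10_000)), batch_size=64, warmup=2, iters=5)
print(f"{ips:.1f} images/s")
```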

ImageNet

Prerequisites

A conda virtual environment is recommended.

conda create -n repvit python=3.8
conda activate repvit
pip install -r requirements.txt

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The training and validation data are expected to be in the train folder and val folder respectively:

|-- /path/to/imagenet/
    |-- train
    |-- val
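The tree above shows only the two split folders. Many ImageNet pipelines built on torchvision's ImageFolder additionally expect one sub-folder per class inside each split. Assuming that convention applies here (this README does not state it), a hypothetical stdlib helper like the one below can sanity-check the layout before a long training run; ImageNet-1K should report 1000 class folders per split.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return {split: number_of_class_folders} for train and val (hypothetical helper).

    Assumes the ImageFolder-style convention: each split holds one
    sub-directory per class, with that class's images inside it.
    Raises FileNotFoundError if a split folder is missing.
    """
    root = Path(root).expanduser()
    counts = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing folder: {split_dir}")
        # Count only directories; stray files such as label lists are ignored.
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    return counts


# Example usage (path is a placeholder):
# print(check_imagenet_layout("~/imagenet"))
```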

Training

To train RepViT-M0.9 on an 8-GPU machine:

python -m torch.distributed.launch --nproc_per_node=8 --master_port 12346 --use_env main.py --model repvit_m0_9 --data-path ~/imagenet --dist-eval

Tips: specify your data path and model name!

Testing

For example, to test RepViT-M0.9:

python main.py --eval --model repvit_m0_9 --resume pretrain/repvit_m0_9_distill_300e.pth --data-path ~/imagenet
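The Top-1 column in the table above is standard top-k accuracy with k = 1: the percentage of validation images whose highest-scoring class is the true label. The NumPy sketch below (a hypothetical helper, not the repository's evaluation code) computes it from a matrix of logits, with top-k for larger k as a by-product.

```python
import numpy as np


def topk_accuracy(logits, labels, ks=(1, 5)):
    """Percentage of samples whose true label is among the k highest-scoring classes.

    logits: (N, num_classes) scores, labels: (N,) integer class indices.
    Returns {k: accuracy_percent}.
    """
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    # Class indices sorted by descending score for each sample.
    order = np.argsort(-logits, axis=1)
    return {k: float(100.0 * np.mean(np.any(order[:, :k] == labels[:, None], axis=1))) for k in ks}


logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5],
                   [0.6, 0.1, 0.3]])
labels = np.array([1, 1, 2, 2])
print(topk_accuracy(logits, labels, ks=(1, 2)))  # {1: 50.0, 2: 100.0}
```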

Downstream Tasks

Object Detection and Instance Segmentation
Semantic Segmentation

Acknowledgement

Classification (ImageNet) code base is partly built with LeViT, PoolFormer and EfficientFormer.

The detection and segmentation pipeline is from MMCV (MMDetection and MMSegmentation).

Thanks for the great implementations!

Citation

If our code or models help your work, please cite our paper:

@misc{wang2023repvit,
      title={RepViT: Revisiting Mobile CNN From ViT Perspective}, 
      author={Ao Wang and Hui Chen and Zijia Lin and Jungong Han and Guiguang Ding},
      year={2023},
      eprint={2307.09283},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@misc{wang2023repvitsam,
      title={RepViT-SAM: Towards Real-Time Segmenting Anything}, 
      author={Ao Wang and Hui Chen and Zijia Lin and Jungong Han and Guiguang Ding},
      year={2023},
      eprint={2312.05760},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
