Git Product home page Git Product logo

ffvt's Introduction

Feature Fusion Vision Transformer for Fine-Grained Visual Categorization

Official PyTorch implementation of Feature Fusion Vision Transformer for Fine-Grained Visual Categorization (BMVC 2021).

If you use the code in this repo for your work, please cite the following bib entries:

@article{wang2021feature,
  title={Feature Fusion Vision Transformer for Fine-Grained Visual Categorization},
  author={Wang, Jun and Yu, Xiaohan and Gao, Yongsheng},
  journal={British Machine Vision Conference},
  year={2021}
}

Abstract

The core for tackling the fine-grained visual categorization (FGVC) is to learn subtleyet discriminative features. Most previous works achieve this by explicitly selecting thediscriminative parts or integrating the attention mechanism via CNN-based approaches.However, these methods enhance the computational complexity and make the modeldominated by the regions containing the most of the objects. Recently, vision trans-former (ViT) has achieved SOTA performance on general image recognition tasks. Theself-attention mechanism aggregates and weights the information from all patches to theclassification token, making it perfectly suitable for FGVC. Nonetheless, the classifi-cation token in the deep layer pays more attention to the global information, lackingthe local and low-level features that are essential for FGVC. In this work, we proposea novel pure transformer-based framework Feature Fusion Vision Transformer (FFVT)where we aggregate the important tokens from each transformer layer to compensate thelocal, low-level and middle-level information. We design a novel token selection mod-ule called mutual attention weight selection (MAWS) to guide the network effectivelyand efficiently towards selecting discriminative tokens without introducing extra param-eters. We verify the effectiveness of FFVT on four benchmarks where FFVT achievesthe state-of-the-art performance.

Prerequisites

The following packages are required to run the scripts:

  • [Python >= 3.6]
  • [PyTorch = 1.5]
  • [Torchvision]
  • [Apex]

Download Google pre-trained ViT models

wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz

Dataset

You can download the datasets from the links below:

Training scripts for FFVT on Cotton dataset.

Train the model on the Cotton dataset. We run our experiments on 4x2080Ti/4x1080Ti with the batchsize of 8 for each card.

$ python3 -m torch.distributed.launch --nproc_per_node 4 train.py --name {name} --dataset cotton --model_type ViT-B_16 --pretrained_dir {pretrained_model_dir} --img_size 384 --resize_size 500 --train_batch_size 8 --learning_rate 0.02 --num_steps 2000 --fp16 --eval_every 16 --feature_fusion

Training scripts for FFVT on Soy.Loc dataset.

Train the model on the SoyLoc dataset. We run our experiments on 4x2080Ti/4x1080Ti with the batchsize of 8 for each card.

$ python3 -m torch.distributed.launch --nproc_per_node 4 train.py --name {name} --dataset soyloc --model_type ViT-B_16 --pretrained_dir {pretrained_model_dir} --img_size 384 --resize_size 500 --train_batch_size 8 --learning_rate 0.02 --num_steps 4750 --fp16 --eval_every 50 --feature_fusion

Training scripts for FFVT on CUB dataset.

Train the model on the CUB dataset. We run our experiments on 4x2080Ti/4x1080Ti with the batchsize of 8 for each card.

$ python3 -m torch.distributed.launch --nproc_per_node 4 train.py --name {name} --dataset CUB --model_type ViT-B_16 --pretrained_dir {pretrained_model_dir} --img_size 448 --resize_size 600 --train_batch_size 8 --learning_rate 0.02 --num_steps 10000 --fp16 --eval_every 200 --feature_fusion

Training scripts for FFVT on Dogs dataset.

Train the model on the Dog dataset. We run our experiments on 4x2080Ti/4x1080Ti with the batchsize of 4 for each card.

$ python3 -m torch.distributed.launch --nproc_per_node 4 train.py --name {name} --dataset CUB --model_type ViT-B_16 --pretrained_dir {pretrained_model_dir} --img_size 448 --resize_size 600 --train_batch_size 4 --learning_rate 3e-3 --num_steps 30000 --fp16 --eval_every 300 --feature_fusion --decay_type linear --num_token 24

Download Models

Trained model Google Drive

Acknowledgment

Thanks for the advice and guidance given by Dr.Xiaohan Yu and Prof. Yongsheng Gao.

Our project references the codes in the following repos. Thanks for thier works and sharing.

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.