
Long-tailed Classification via CAscaded Deep Ensemble Learning

Zhi Chen, Jiang Duan, Yu Xiong, Cheng Yang and Guoping Qiu

This repository is the official PyTorch implementation of the paper CADEL.

Environments

pytorch >= 1.8.0
timm == 0.3.2
  1. Since PyTorch 1.8.0 or later is required, timm 0.3.2 must be patched before it will work with it.

Data preparation

You can download the original datasets as follows:

Then change data_root in main_CNN.py and main_ViT.py to point to your dataset location.

After preparation, the directory structure should look like this (ImageNet-LT shown as an example):

/path/to/ImageNet-LT/
    train/
        class1/
            img1.jpeg
        class2/
            img2.jpeg
    val/
        class1/
            img3.jpeg
        class2/
            img4.jpeg
    train.txt
    val.txt
    test.txt
    num_shots.txt

train.txt, val.txt and test.txt list the file names, and num_shots.txt gives the number of training images in each class. All these data files have been uploaded to this repo.
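A minimal sketch for loading these files is below. It assumes, following the ImageNet-LT lists used by cRT, that each line of train.txt, val.txt and test.txt is "<relative image path> <class index>", and that num_shots.txt has one class per line in class-index order with the count as the last token. Both formats are assumptions, so check the files in this repo; read_split and read_num_shots are hypothetical helper names, not functions from this codebase.

```python
from pathlib import Path
from typing import List, Tuple


def read_split(list_file: str, data_root: str) -> List[Tuple[str, int]]:
    """Parse a split file into (full image path, class index) pairs.

    Assumed line format: '<relative path> <class index>'.
    """
    samples = []
    for line in Path(list_file).read_text().splitlines():
        if not line.strip():
            continue  # skip blank lines
        # Split on the last whitespace run so paths containing spaces survive.
        rel_path, label = line.rsplit(maxsplit=1)
        samples.append((str(Path(data_root) / rel_path), int(label)))
    return samples


def read_num_shots(shots_file: str) -> List[int]:
    """Return per-class training-image counts.

    Assumes one class per line, in class-index order, with the count as the
    last whitespace-separated token (so both '120' and '0 120' parse).
    """
    counts = []
    for line in Path(shots_file).read_text().splitlines():
        if line.strip():
            counts.append(int(line.split()[-1]))
    return counts
```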

Usage

  1. All of our training settings are in ./config/.

  2. Our experiments typically use 2 GPUs with at least 24 GB of memory each. Training ViT-B-16 at a resolution of 384x384 requires GPUs with more memory.

Stage one can be trained with either DataParallel or DistributedDataParallel. The commands are:

# Stage one (DataParallel)
python main_CNN.py    # or: python main_ViT.py to train a ViT
# Stage one (DistributedDataParallel)
python -m torch.distributed.launch --nproc_per_node=n main_CNN.py

where n is the number of GPUs on your server. With DistributedDataParallel, divide the default batch_size in our configs by n: each process loads its own batch, so this keeps the total batch size unchanged.

# Stage two
python main_CNN_PC.py
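The batch-size adjustment for distributed stage-one training is simple arithmetic: under DistributedDataParallel each of the n processes draws its own mini-batch, so the effective batch is the per-process batch_size times n. A small sketch (per_gpu_batch_size is a hypothetical helper, and 256 is an illustrative value, not taken from our configs):

```python
def per_gpu_batch_size(default_batch_size: int, num_gpus: int) -> int:
    """Split a config's batch_size across DistributedDataParallel processes.

    Each process loads its own mini-batch, so the effective batch size is
    per_gpu * num_gpus; integer division keeps it equal to the default.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    if default_batch_size % num_gpus != 0:
        raise ValueError(
            f"batch_size {default_batch_size} is not divisible by {num_gpus} GPUs"
        )
    return default_batch_size // num_gpus


# Example: a hypothetical config batch_size of 256 trained on 2 GPUs.
print(per_gpu_batch_size(256, 2))  # prints 128
```

With DataParallel no change is needed, because a single process splits one batch of batch_size across the GPUs.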

Results of CNNs

Dataset       Many   Medium  Few    All    Model
ImageNet-LT   67.5   55.6    43.2   58.5   ResNet50
ImageNet-LT   68.8   55.8    44.0   59.2   ResNeXt50
iNat18        ---    ---     ---    73.5   ResNet50
Places-LT     ---    ---     ---    41.4   ResNet152

Results of ViTs

Dataset       Resolution  Many   Med.   Few    Acc    Pretrain ckpt
ImageNet-LT   224x224     70.3   59.8   47.5   61.7   Res_224
ImageNet-LT   384x384     73.0   62.2   50.3   64.7
iNat18        224x224     77.7   76.3   75.1   76.2   Res_128
iNat18        384x384     75.0   81.8   85.4   82.7
Places-LT     224x224     46.6   46.7   46.5   46.6   Image-1K-224
Places-LT     384x384     47.9   50.2   38.5   47.1   Image-1K-384
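Many, Medium (Med.) and Few are accuracies on groups of classes defined by how many training images each class has. A minimal sketch of how such numbers can be computed, assuming the common long-tailed convention (many: more than 100 training images, medium: 20 to 100, few: fewer than 20), which we have not confirmed for every dataset above; the per-class counts would come from num_shots.txt, and shot_group and group_accuracy are hypothetical names:

```python
from typing import Dict, Sequence


def shot_group(count: int, many_thr: int = 100, few_thr: int = 20) -> str:
    """Assign a class to a group by its training-image count (assumed thresholds)."""
    if count > many_thr:
        return "many"
    if count < few_thr:
        return "few"
    return "medium"


def group_accuracy(preds: Sequence[int], labels: Sequence[int],
                   train_counts: Sequence[int]) -> Dict[str, float]:
    """Top-1 accuracy (%) per many/medium/few group and over all samples."""
    correct = {"many": 0, "medium": 0, "few": 0}
    total = {"many": 0, "medium": 0, "few": 0}
    for p, y in zip(preds, labels):
        g = shot_group(train_counts[y])  # group of the ground-truth class
        total[g] += 1
        correct[g] += int(p == y)
    result = {g: 100.0 * correct[g] / total[g] for g in total if total[g] > 0}
    result["all"] = 100.0 * sum(correct.values()) / max(sum(total.values()), 1)
    return result


# Toy example: class 0 is many-shot, class 1 medium-shot, class 2 few-shot.
acc = group_accuracy(preds=[0, 0, 1, 2, 2, 0], labels=[0, 0, 1, 1, 2, 2],
                     train_counts=[500, 50, 5])
print(acc)  # many 100.0, medium 50.0, few 50.0, all ~66.7
```

On a class-balanced test set, such as that of ImageNet-LT, the per-group sample accuracy equals the mean per-class accuracy within each group.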

Citation

If you find our ideas or code useful, please cite our paper:

@article{CADEL,
  title={Long-tailed Classification via CAscaded Deep Ensemble Learning},
  author={Chen, Zhi and Duan, Jiang and Xiong, Yu and Yang, Cheng and Qiu, Guoping},
  year={2023},
  archivePrefix={arXiv},
  primaryClass={cs.AI}
}

This code is partially based on cRT and LiVT. If you use our code, please also cite:

@inproceedings{kang2019decoupling,
  title={Decoupling representation and classifier for long-tailed recognition},
  author={Kang, Bingyi and Xie, Saining and Rohrbach, Marcus and Yan, Zhicheng
          and Gordo, Albert and Feng, Jiashi and Kalantidis, Yannis},
  booktitle={Eighth International Conference on Learning Representations (ICLR)},
  year={2020}
}
@inproceedings{LiVT,
  title={Learning Imbalanced Data with Vision Transformers},
  author={Xu, Zhengzhuo and Liu, Ruikang and Yang, Shuo and Chai, Zenghao and Yuan, Chun},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023}
}

Acknowledgements

This project is based on cRT and LiVT.
