Contrastive Mean-Shift Learning for Generalized Category Discovery

Pohang University of Science and Technology (POSTECH)

Environment installation

This project is built upon the following environment:

The package requirements can be installed via requirements.txt:

pip install -r requirements.txt

Datasets

We use fine-grained benchmarks in this paper, including:

We also use generic object recognition datasets, including:

Please follow this repo to set up the data.

Download the datasets, SSB splits, and pretrained backbone, arrange them according to the file structure below, and set DATASET_ROOT={YOUR DIRECTORY} in config.py.

    DATASET_ROOT/
    ├── cifar100/
    │   ├── cifar-100-python/
    │   │   ├── meta/
    │   │   ├── ...
    ├── CUB_200_2011/
    │   ├── attributes/
    │   ├── ...
    ├── ...
    CMS/
    ├── data/
    │   ├── ssb_splits/
    ├── models/
    │   ├── dino_vitbase16_pretrain.pth
    ├── ...
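
Before training, it can help to confirm that the files are where the code expects them. The following is a minimal sketch that is not part of the repository: `missing_entries` and the two checklists are hypothetical and only cover the entries shown in the tree above, so extend them for the other datasets you use.

```python
from pathlib import Path

# Hypothetical checklists copied from the tree above; add entries for the
# other datasets you download.
EXPECTED_UNDER_DATASET_ROOT = [
    "cifar100/cifar-100-python/meta",
    "CUB_200_2011/attributes",
]
EXPECTED_UNDER_REPO = [
    "data/ssb_splits",
    "models/dino_vitbase16_pretrain.pth",
]


def missing_entries(dataset_root, repo_root):
    """Return every expected path (file or directory) that does not exist yet."""
    missing = []
    checks = [
        (Path(dataset_root), EXPECTED_UNDER_DATASET_ROOT),
        (Path(repo_root), EXPECTED_UNDER_REPO),
    ]
    for root, entries in checks:
        for rel in entries:
            path = root / rel
            if not path.exists():
                missing.append(str(path))
    return missing
```

Calling `missing_entries(DATASET_ROOT, CMS_ROOT)` with your own two directories returns an empty list when every listed entry is present, and otherwise the full paths that still need to be downloaded or moved.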

Training

bash bash_scripts/contrastive_meanshift_training.sh

Example bash commands for training are as follows:

# GCD
python -m methods.contrastive_meanshift_training  \
            --dataset_name 'cub' \
            --lr 0.05 \
            --temperature 0.25 \
            --wandb 

# Inductive GCD
python -m methods.contrastive_meanshift_training  \
            --dataset_name 'cub' \
            --lr 0.05 \
            --temperature 0.25 \
            --inductive \
            --wandb 
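
The commands above set `--temperature 0.25`. To illustrate what a temperature does in a contrastive objective, here is a generic one-directional InfoNCE loss in NumPy. This is a hedged sketch, not the repository's loss: the method's name suggests its contrastive objective is applied to mean-shifted embeddings, and details such as how positives are chosen may differ. The function name `info_nce` is hypothetical.

```python
import numpy as np


def info_nce(z1, z2, temperature=0.25):
    """Generic InfoNCE loss: row i of z1 and row i of z2 form a positive pair,
    and every other row of z2 acts as a negative for row i of z1."""
    # L2-normalize so that dot products are cosine similarities.
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature  # (N, N) similarities scaled by 1/tau
    # Subtract the row max for numerical stability before the log-softmax.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Average negative log-probability of the positive (diagonal) entries.
    return -np.mean(np.diag(log_prob))
```

Lowering `temperature` sharpens the softmax over similarities, so negatives that are already close to the anchor receive more weight in the gradient.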

Evaluation

bash bash_scripts/meanshift_clustering.sh

An example bash command for evaluation is as follows. Change --model_name to match your own trained model or downloaded checkpoint.

python -m methods.meanshift_clustering \
        --dataset_name 'cub' \
        --model_name 'cub_best'
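
The evaluation log quoted in the Aircraft issue below reports `k=8, alpha=0.5` and prints accuracy after each of three mean-shift iterations. The sketch below shows one plausible kNN-based iterative mean shift on L2-normalized embeddings using those parameter names. The update rule (add `alpha` times the unweighted mean of the `k` nearest neighbors, then re-normalize) and the function names are assumptions; the repository may weight neighbors differently. Clustering the shifted embeddings into K groups, which the script does afterwards, is not shown.

```python
import numpy as np


def knn_mean_shift_step(z, k=8, alpha=0.5):
    """One hypothetical kNN mean-shift step; requires more than k points."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    sim = z @ z.T                      # cosine similarity between all pairs
    np.fill_diagonal(sim, -np.inf)     # a point is not its own neighbor
    nn = np.argsort(-sim, axis=1)[:, :k]  # indices of the k most similar points
    # Move each embedding toward the mean of its neighbors, then re-normalize.
    shifted = z + alpha * z[nn].mean(axis=1)
    return shifted / np.linalg.norm(shifted, axis=1, keepdims=True)


def iterative_mean_shift(z, steps=3, k=8, alpha=0.5):
    """Apply the step repeatedly; each step pulls points toward local modes."""
    for _ in range(steps):
        z = knn_mean_shift_step(z, k=k, alpha=alpha)
    return z
```

Each step contracts dense neighborhoods, so embeddings of the same category move closer together before the final clustering.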

Results and checkpoints

Experimental results on the GCD task (clustering accuracy, %):

    Dataset          All    Old    Novel   Checkpoints
    CIFAR100         82.3   85.7   75.5    link
    ImageNet100      84.7   95.6   79.2    link
    CUB              68.2   76.5   64.0    link
    Stanford Cars    56.9   76.1   47.6    link
    FGVC-Aircraft    56.0   63.4   52.3    link
    Herbarium19      36.4   54.9   26.4    link

Experimental results on the inductive GCD task (clustering accuracy, %):

    Dataset          All    Old    Novel   Checkpoints
    CIFAR100         80.7   84.4   65.9    link
    ImageNet100      85.7   95.7   75.8    link
    CUB              69.7   76.5   63.0    link
    Stanford Cars    57.8   75.2   41.0    link
    FGVC-Aircraft    53.3   62.7   43.8    link
    Herbarium19      46.2   53.0   38.9    link
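
The All / Old / Novel columns report clustering accuracy on all unlabeled samples and on the subsets belonging to labeled ("Old") and unlabeled ("Novel") classes. The evaluation log quoted in the issues below uses `eval_funcs=['v2']`; as we understand the GCD protocol this codebase builds on, that variant computes one Hungarian matching between clusters and classes over all samples and then splits the accuracy. The sketch below is written from scratch under that assumption with `scipy.optimize.linear_sum_assignment`; the name `split_cluster_acc` is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def split_cluster_acc(y_true, y_pred, old_mask):
    """Hungarian-matched accuracy on all samples, then split into Old / Novel.

    old_mask[i] is True when sample i belongs to a labeled ("Old") class.
    Returns (all_acc, old_acc, novel_acc) as fractions in [0, 1].
    """
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    old_mask = np.asarray(old_mask, dtype=bool)
    d = int(max(y_pred.max(), y_true.max())) + 1
    # Co-occurrence counts: w[p, t] = number of samples in cluster p with class t.
    w = np.zeros((d, d), dtype=np.int64)
    for p, t in zip(y_pred, y_true):
        w[p, t] += 1
    # One global cluster-to-class matching that maximizes correct assignments.
    rows, cols = linear_sum_assignment(w, maximize=True)
    mapping = dict(zip(rows.tolist(), cols.tolist()))
    correct = np.array([mapping[p] == t for p, t in zip(y_pred.tolist(), y_true.tolist())])
    return correct.mean(), correct[old_mask].mean(), correct[~old_mask].mean()
```

Multiplying by 100 gives the scale used in the tables; for example, the 0.5393 "All" accuracy in the Aircraft log corresponds to 53.9.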

Citation

If you find our code or paper useful, please consider citing our paper:

  @inproceedings{choi2024contrastive,
    title={Contrastive Mean-Shift Learning for Generalized Category Discovery},
    author={Choi, Sua and Kang, Dahyun and Cho, Minsu},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year={2024}
  }

Related Repos

The codebase is largely built on Generalized Category Discovery and PromptCAL.

Acknowledgements

This work was supported by the NRF grant (NRF-2021R1A2C3012728 (50%)) and the IITP grants (2022-0-00113: Developing a Sustainable Collaborative Multi-modal Lifelong Learning Framework (45%), 2019-0-01906: AI Graduate School Program at POSTECH (5%)) funded by Ministry of Science and ICT, Korea.

Contributors

dahyun-kang, sua-choi

Issues

Question about the Aircraft dataset

Excuse me, I downloaded the provided weights for the Aircraft dataset, but the result is:

Namespace(batch_size=128, num_workers=8, pretrain_path='./models', transform='imagenet', eval_funcs=['v2'], use_ssb_splits=True, model_name='aircraft_p', dataset_name='aircraft', epochs=20, feat_dim=768, num_clusters=None, inductive=False, k=8, alpha=0.5, image_size=224, train_classes=[0, 1, 2, 3, 4, 5, 10, 11, 14, 16, 17, 19, 21, 22, 23, 24, 27, 28, 29, 30, 33, 36, 37, 38, 39, 41, 43, 44, 45, 46, 47, 48, 52, 53, 56, 57, 58, 63, 64, 65, 66, 67, 71, 73, 76, 77, 79, 92, 95, 99], unlabeled_classes=[6, 7, 8, 9, 12, 13, 31, 32, 25, 26, 18, 20, 15, 78, 82, 51, 49, 50, 54, 55, 59, 60, 61, 68, 69, 70, 85, 86, 87, 88, 80, 81, 42, 84, 40, 90, 74, 75, 97, 98, 34, 35, 93, 94, 96, 72, 91, 83, 62, 89], num_labeled_classes=50, num_unlabeled_classes=50)
Using weights from /home/GEN/code/CMS/log/cms/log/aircraft_p/checkpoints/model_best.pt ...
/home/GEN/miniconda3/envs/torch1/lib/python3.10/site-packages/torchvision/transforms/transforms.py:329: UserWarning: Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
  warnings.warn(
Predicted number of clusters:  92
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:34<00:00,  1.54it/s]
num clusters 100
Epoch 0, IMS unlabeled train ACC_v2: All 0.5353 | Old 0.6158 | New 0.4951
num clusters 92
Epoch 0, IMS unlabeled train ACC_v2: All 0.5373 | Old 0.6176 | New 0.4972
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:19<00:00,  2.74it/s]
num clusters 100
Epoch 1, IMS unlabeled train ACC_v2: All 0.5343 | Old 0.6170 | New 0.4930
num clusters 92
Epoch 1, IMS unlabeled train ACC_v2: All 0.5433 | Old 0.5984 | New 0.5157
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:19<00:00,  2.71it/s]
num clusters 100
Epoch 2, IMS unlabeled train ACC_v2: All 0.5393 | Old 0.6164 | New 0.5007
num clusters 92
Epoch 2, IMS unlabeled train ACC_v2: All 0.5461 | Old 0.5936 | New 0.5223
ACC with GT number of clusters: All 0.5393 | Old 0.6164 | New 0.5007
ACC with predicted number of clusters: All 0.5461 | Old 0.5936 | New 0.5223

The result is lower than the number reported in the README (53 vs. 56).

I also tried training the model myself, but the results are also lower than those in the paper (CUB 66, SCARS 48, AIRCRAFT 51). I used the default hyperparameters from GitHub and did not change the code. Perhaps something is wrong with my experiment; do you have any ideas that could help me fix the problem?

I look forward to your response at your earliest convenience.

About clustering without the ground-truth number of clusters K

Thank you for the great work on GCD.
I would like to ask how to train the model on CUB without the ground-truth K value. Which part of the code should I change?
Looking forward to your response, and thank you once again for your great work!
