CaCo

CaCo is a contrastive-learning-based self-supervised learning method, which has been submitted to IEEE T-PAMI.

Copyright (C) 2020 Xiao Wang, Yuhang Huang, Dan Zeng, Guo-Jun Qi

License: MIT for academic use.

Contact: Xiao Wang ([email protected]), Guo-Jun Qi ([email protected])

Introduction

As a representative self-supervised method, contrastive learning has achieved great success in unsupervised representation learning. It trains an encoder by distinguishing positive samples from negative ones given query anchors. These positive and negative samples play critical roles in defining the objective for learning a discriminative encoder and prevent it from learning trivial features. While existing methods choose these samples heuristically, we present a principled method in which both positive and negative samples are directly learnable end-to-end with the encoder. We show that the positive and negative samples can be learned cooperatively and adversarially by minimizing and maximizing the contrastive loss, respectively. This yields cooperative positives and adversarial negatives with respect to the encoder, which are updated to continuously track the learned representation of the query anchors over mini-batches. The proposed method achieves 72.0% and 75.3% top-1 accuracy after 200 and 800 epochs, respectively, of pre-training a ResNet-50 backbone on ImageNet1K, without tricks such as multi-crop or stronger augmentations. With multi-crop, accuracy is further boosted to 75.7%.
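
To make the objective concrete, here is a minimal pure-Python sketch of a generic InfoNCE-style contrastive loss for one query against a memory bank. It is an illustration under stated assumptions, not the repository's implementation: the function name `info_nce_loss`, the use of plain lists, and the choice of one positive index per query are all hypothetical. The default temperature of 0.08 mirrors the `--moco_t` value used in the commands in the Usage section. In the method described above, the encoder and the positive would be updated to decrease this value, while the negatives would be updated to increase it.

```python
import math


def info_nce_loss(query, memory, positive_index, temperature=0.08):
    """InfoNCE-style loss for one query (illustrative sketch only).

    query: list of floats, assumed L2-normalized.
    memory: list of vectors with the same length as query, assumed L2-normalized.
    positive_index: index of the memory entry treated as the positive;
        every other entry acts as a negative.
    """
    # Dot-product similarities, scaled by the temperature.
    logits = [
        sum(q * m for q, m in zip(query, row)) / temperature
        for row in memory
    ]
    # Numerically stable log-sum-exp over all memory entries.
    max_logit = max(logits)
    log_denominator = max_logit + math.log(
        sum(math.exp(value - max_logit) for value in logits)
    )
    # Negative log-probability of picking the positive entry.
    return log_denominator - logits[positive_index]


# Toy example: the loss is lower when the positive is aligned with the query.
query = [1.0, 0.0]
memory = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
aligned = info_nce_loss(query, memory, positive_index=0)
opposed = info_nce_loss(query, memory, positive_index=2)
print(aligned < opposed)
```

The loss is never negative, because the denominator always includes the positive term itself.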

Installation

CUDA version should be 10.1 or higher.

1. Install git

2. Clone the repository to your computer

git clone [email protected]:maple-research-lab/CaCo.git && cd CaCo

3. Build dependencies.

You have two options for installing the dependencies on your computer:

3.1 Install with pip and Python (version 3.6.9).

3.1.1 Install pip.
3.1.2 Install the dependencies from the command line:
pip install -r requirements.txt --user

If you encounter any errors, you can install each library one by one:

pip install "torch>=1.7.1"
pip install "torchvision>=0.8.2"
pip install "numpy>=1.19.5"
pip install "Pillow>=5.1.0"
pip install "tensorboard>=1.14.0"
pip install "tensorboardX>=1.7"

3.2 Install with anaconda

3.2.1 Install conda.
3.2.2 Install the dependencies from the command line:
conda create -n CaCo python=3.7.1
conda activate CaCo
pip install -r requirements.txt 

Each time you want to run the code, activate the environment with:

conda activate CaCo
conda deactivate  # when you want to exit the environment

4. Prepare the ImageNet dataset

4.1 Download the ImageNet2012 dataset to "./datasets/imagenet2012".
4.2 Go to "./datasets/imagenet2012/val".
4.3 Move the validation images into labeled subfolders using the following shell script.
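
As a rough illustration of what step 4.3 does, the following Python sketch (not the shell script referred to above) moves each validation image into a subfolder named after its class. It assumes a hypothetical two-column CSV file mapping image filenames to class IDs. That file, its name, and the function `sort_validation_images` are assumptions made for this example and are not part of the repository.

```python
import csv
import shutil
from pathlib import Path


def sort_validation_images(val_dir, mapping_csv):
    """Move validation images into per-class subfolders.

    val_dir: folder that currently holds all validation images flat.
    mapping_csv: hypothetical CSV with rows of "filename,class_id".
    Returns the number of files moved.
    """
    val_dir = Path(val_dir)
    moved = 0
    with open(mapping_csv, newline="") as handle:
        for row in csv.reader(handle):
            if len(row) < 2:
                continue  # skip blank or malformed rows
            filename, class_id = row[0].strip(), row[1].strip()
            source = val_dir / filename
            if not source.is_file():
                continue  # already moved, or not present in this folder
            target_dir = val_dir / class_id
            target_dir.mkdir(exist_ok=True)
            shutil.move(str(source), str(target_dir / filename))
            moved += 1
    return moved
```

A call such as `sort_validation_images("./datasets/imagenet2012/val", "val_labels.csv")` would then produce one subfolder per class, where `val_labels.csv` is a placeholder name for the mapping file. Running it a second time moves nothing, because files that are no longer in the top-level folder are skipped.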

Usage

1. Single-Crop Unsupervised Pre-Training

1.1 Training with batch size of 1024 (Single Machine)

With a batch size of 1024, training can run on a single machine with 8 V100 32GB GPUs using the following command:

python3 main.py --type=0 --lr=0.3 --lr_final=0.003 --memory_lr=3.0 --memory_lr_final=3.0 --cluster=65536 --moco_t=0.08 --mem_t=0.08 --data=datasets/imagenet --dist_url=tcp://localhost:10001 --batch_size=1024 --wd=1.5e-6 --mem_wd=0 --moco_dim=256 --moco_m=0.99 --moco_m_decay=1 --mlp_dim=2048 --epochs=200 --warmup_epochs=10 --nodes_num=1 --workers=32 --world_size 1 --rank=0 --mem_momentum=0.9 --ad_init=1 --knn_batch_size=1024 --multi_crop=0 --knn_freq=10

This should reproduce our 71.3% top-1 accuracy with batch size 1024.

1.2 Training with batch size of 4096 (4 Machines)

This can only be run on multiple machines. Limited by our computing resources, we ran experiments with a batch size of 4096 on 4 machines with 8 V100 GPUs each. On the first node, run the following command:

python3 main.py --type=0 --lr=0.3 --lr_final=0.003 --memory_lr=3.0 --memory_lr_final=3.0 --cluster=65536 --moco_t=0.08 --mem_t=0.08 --data=datasets/imagenet --dist_url=tcp://localhost:10001 --batch_size=4096 --wd=1.5e-6 --mem_wd=0 --moco_dim=256 --moco_m=0.99 --moco_m_decay=1 --mlp_dim=2048 --epochs=200 --warmup_epochs=10 --nodes_num=1 --workers=128 --world_size 4 --rank=0 --mem_momentum=0.9 --ad_init=1 --knn_batch_size=1024 --multi_crop=0 --knn_freq=20

Then run the following command on each of the other nodes:

python3 main.py --type=0 --lr=0.3 --lr_final=0.003 --memory_lr=3.0 --memory_lr_final=3.0 --cluster=65536 --moco_t=0.08 --mem_t=0.08 --data=datasets/imagenet --dist_url=tcp://[master_id]:10001 --batch_size=4096 --wd=1.5e-6 --mem_wd=0 --moco_dim=256 --moco_m=0.99 --moco_m_decay=1 --mlp_dim=2048 --epochs=200 --warmup_epochs=10 --nodes_num=1 --workers=128 --world_size 4 --rank=[rank_id] --mem_momentum=0.9 --ad_init=1 --knn_batch_size=1024 --multi_crop=0 --knn_freq=20

Here, replace [master_id] with the IP address of the first node, and set [rank_id] to 1, 2, and 3 on the three other nodes.
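
The four per-node commands differ only in the address given to --dist_url and in the value of --rank. The sketch below generates them from one shared flag string copied from the commands above. It assumes that main.py accepts its flags in any order, which is typical for argparse-based scripts but is not verified here. The helper name `build_node_command` and the example address 10.0.0.1 are hypothetical.

```python
# Flags shared by every node, copied from the batch-size-4096 commands above.
SHARED_FLAGS = (
    "--type=0 --lr=0.3 --lr_final=0.003 --memory_lr=3.0 --memory_lr_final=3.0 "
    "--cluster=65536 --moco_t=0.08 --mem_t=0.08 --data=datasets/imagenet "
    "--batch_size=4096 --wd=1.5e-6 --mem_wd=0 --moco_dim=256 --moco_m=0.99 "
    "--moco_m_decay=1 --mlp_dim=2048 --epochs=200 --warmup_epochs=10 "
    "--nodes_num=1 --workers=128 --mem_momentum=0.9 --ad_init=1 "
    "--knn_batch_size=1024 --multi_crop=0 --knn_freq=20"
)


def build_node_command(rank, master_ip, world_size=4, port=10001):
    """Return the pre-training command line for the node with the given rank."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    # The first node addresses itself as localhost, as in the command above.
    host = "localhost" if rank == 0 else master_ip
    return (
        f"python3 main.py {SHARED_FLAGS} "
        f"--dist_url=tcp://{host}:{port} --world_size {world_size} --rank={rank}"
    )


# Print one command per node; 10.0.0.1 is a placeholder for the first node's IP.
for rank in range(4):
    print(build_node_command(rank, master_ip="10.0.0.1"))
```

Generating the commands from one place reduces the chance that a hyperparameter differs between nodes by accident.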

2. Multi-Crop Unsupervised Pre-Training (4 Machines)

This can only be run on multiple machines. Limited by our computing resources, we ran experiments with a batch size of 2048 on 4 machines with 8 V100 GPUs each. On the first node, run the following command:

python3 main.py --type=0 --lr=0.3 --lr_final=0.003 --memory_lr=3.0 --memory_lr_final=3.0 --cluster=65536 --moco_t=0.08 --mem_t=0.08 --data=datasets/imagenet --dist_url=tcp://localhost:10001 --batch_size=2048 --wd=1.5e-6 --mem_wd=0 --moco_dim=256 --moco_m=0.99 --moco_m_decay=1 --mlp_dim=2048 --epochs=800 --warmup_epochs=10 --nodes_num=4 --workers=128 --world_size 4 --rank=0 --mem_momentum=0.9 --ad_init=1 --knn_batch_size=2048 --multi_crop=1 --knn_freq=50

Then run the following command on each of the other nodes:

python3 main.py --type=0 --lr=0.3 --lr_final=0.003 --memory_lr=3.0 --memory_lr_final=3.0 --cluster=65536 --moco_t=0.08 --mem_t=0.08 --data=datasets/imagenet --dist_url=tcp://[master_id]:10001 --batch_size=2048 --wd=1.5e-6 --mem_wd=0 --moco_dim=256 --moco_m=0.99 --moco_m_decay=1 --mlp_dim=2048 --epochs=800 --warmup_epochs=10 --nodes_num=4 --workers=128 --world_size 4 --rank=[rank_id] --mem_momentum=0.9 --ad_init=1 --knn_batch_size=2048 --multi_crop=1 --knn_freq=50

Here, replace [master_id] with the IP address of the first node, and set [rank_id] to 1, 2, and 3 on the three other nodes.
We believe that further increasing the batch size to 4096 could improve performance.

Linear Classification

With a pre-trained model, we can easily evaluate its performance on ImageNet with:

python linear.py  -a resnet50 --lr 0.025 --batch-size 4096 \
  --pretrained [your checkpoint path] \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed \
  --world-size 1 --rank 0 --data [imagenet path]
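
The metric reported for linear classification is top-1 accuracy: the percentage of images whose highest-scoring class matches the ground-truth label. A minimal sketch of that computation on plain Python lists follows. The function name `top1_accuracy` is hypothetical, and linear.py may compute the metric differently.

```python
def top1_accuracy(logits, labels):
    """Percentage of samples whose highest-scoring class equals the label.

    logits: list of per-sample score lists, one score per class.
    labels: list of integer class indices, one per sample.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for scores, label in zip(logits, labels):
        # Index of the largest score is the predicted class.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        if predicted == label:
            correct += 1
    return 100.0 * correct / len(labels)


# Toy example with four samples and two classes; three predictions are correct.
example_logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
example_labels = [1, 0, 0, 0]
print(top1_accuracy(example_logits, example_labels))  # 75.0
```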

Linear Performance:

Pre-train network | Pre-train epochs | Crop   | Batch size | CaCo top-1 acc. (%) | Model link
ResNet-50         | 200              | Single | 1024       | 71.3                | model
ResNet-50         | 200              | Single | 4096       | 72.0                | model
ResNet-50         | 800              | Single | 4096       | 75.3                | None
ResNet-50         | 800              | Multi  | 2048       | 75.7                | model

Citation:

CaCo: Both Positive and Negative Samples are Directly Learnable via Cooperative-adversarial Contrastive Learning.

@article{wang2022caco,
  title={CaCo: Both Positive and Negative Samples are Directly Learnable via Cooperative-adversarial Contrastive Learning},
  author={Wang, Xiao and Huang, Yuhang and Zeng, Dan and Qi, Guo-Jun},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence (submitted)},
  year={2022}
}
