
Unofficial PyTorch Implementation of SimCLR

An unofficial PyTorch implementation of SimCLR: "A Simple Framework for Contrastive Learning of Visual Representations".

Work in progress: replicating results on ImageNet, TinyImageNet, CIFAR, and STL10.

Results:

Top-1 accuracy (%) of linear evaluation on CIFAR-10:

Testing is performed on the CIFAR-10 test set, while the training set is split into train and validation subsets for tuning.

| Method | Batch Size | ResNet | Projection Head Dim. | Pre-train Epochs | Pre-train Optim. | Finetune Optim. | Finetune LR | Eval Epochs | Acc (%) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SimCLR + Linear eval. | 512 | ResNet18 | 128 | 1000 | LARS | SGD | 0.01 | 100 | 88.53 |
| SimCLR + Linear eval. | 512 | ResNet18 | 128 | 1000 | LARS | LARS | 0.1 | 100 | 90.47 |
| SimCLR + Linear eval. | 512 | ResNet50 | 128 | 1000 | LARS | SGD | 0.01 | 100 | 92.75 |
| SimCLR + Linear eval. | 512 | ResNet50 | 128 | 1000 | LARS | LARS | 0.1 | 100 | 92.45 |
| SimCLR + Linear eval. | 512 | ResNet101 | 128 | 1000 | LARS | SGD | 0.01 | 100 | 93.98 |
| SimCLR + Linear eval. | 512 | ResNet101 | 128 | 1000 | LARS | LARS | 0.1 | 100 | 93.42 |
| Supervised + Linear eval. | 512 | ResNet18 | 128 | 1000 | LARS | LARS | 0.1 | 100 | 93.63 |
| Random Init + Linear eval. | 512 | ResNet18 | 128 | 1000 | LARS | LARS | 0.1 | 100 | 26.37 |

Note: for linear evaluation the ResNet is fully frozen (all layers); training is performed only on the supervised linear evaluation layer.
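For illustration, a minimal sketch of this protocol, assuming a torchvision ResNet-18 backbone (hypothetical code, not the repository's exact implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# Hypothetical sketch: freeze the backbone, train only a linear classifier.
backbone = models.resnet18()
backbone.fc = nn.Identity()          # expose 512-d features instead of logits
# (in practice, load the contrastively pre-trained weights here)

for p in backbone.parameters():
    p.requires_grad = False          # freeze all backbone layers
backbone.eval()

linear = nn.Linear(512, 10)          # 512-d ResNet-18 features -> 10 CIFAR-10 classes
optimiser = torch.optim.SGD(linear.parameters(), lr=0.01, momentum=0.9)

def linear_eval_step(x, y):
    with torch.no_grad():
        h = backbone(x)              # frozen features, no backbone gradients
    loss = F.cross_entropy(linear(h), y)
    optimiser.zero_grad()
    loss.backward()                  # gradients flow only into the linear head
    optimiser.step()
    return loss.item()
```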

Plots:

Training curves for ResNet-18 and ResNet-50.

Usage / Run

Contrastive Training and Linear Evaluation

Launch the script from src/main.py:

python main.py

By default the CIFAR-10 dataset is used; use --dataset to select from: cifar10, cifar100, stl10, imagenet, tinyimagenet. For ImageNet and TinyImageNet, provide a path to the dataset via --dataset_path.

Training uses CosineAnnealingLR decay with linear warmup, following the reference implementation at https://github.com/google-research/simclr/blob/master/lars_optimizer.py: the learning rate ramps up linearly for --warmup_epochs epochs and then decays on a cosine schedule.
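A minimal sketch of this schedule, assuming it is stepped once per epoch (hypothetical helper, not the repository's exact code):

```python
import math

def lr_at_epoch(epoch, base_lr, warmup_epochs, total_epochs):
    """Linear warmup to base_lr, then cosine annealing towards zero."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The same shape can also be built in PyTorch with torch.optim.lr_scheduler.LambdaLR, or by chaining LinearLR and CosineAnnealingLR via SequentialLR.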

DistributedDataParallel

To train with DistributedDataParallel for a computational speedup across multiple GPUs, use:

python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py

This will train on a single machine (nnodes=1), assigning one process per GPU, where nproc_per_node=2 refers to training on 2 GPUs. To train on N GPUs, simply launch N processes by setting nproc_per_node=N.
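Note that recent PyTorch releases deprecate torch.distributed.launch in favour of torchrun, which always exports the environment variables that --use_env provides; assuming a recent PyTorch version, the equivalent launch would be:

torchrun --nnodes=1 --nproc_per_node=2 main.py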

The number of CPU threads used per process is hard-coded to torch.set_num_threads(1) for safety, and can be raised to (# CPU threads / nproc_per_node) for better performance. (fabio-deep)
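A minimal sketch of that change, assuming the WORLD_SIZE environment variable set by the launcher (hypothetical, not the repository's exact code):

```python
import os
import torch

# Give each process an equal share of the machine's CPU threads,
# rather than the hard-coded torch.set_num_threads(1).
world_size = int(os.environ.get("WORLD_SIZE", "1"))
torch.set_num_threads(max(1, os.cpu_count() // world_size))
```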

For more info on multi-node and multi-gpu distributed training refer to https://github.com/hgrover/pytorchdistr/blob/master/README.md

DataParallel

To train with traditional nn.DataParallel with multiple GPUs, use:

python main.py --no_distributed

Note: the default config selects --no_distributed, so running python main.py uses the default hyperparameters without DistributedDataParallel.

Linear Evaluation of a Pre-Trained Model

To evaluate the performance of a pre-trained model on a linear classification task, simply include the flag --finetune and provide a path to the pre-trained model via --load_checkpoint_dir.

Example:

python main.py --no_distributed --finetune --load_checkpoint_dir ~/Documents/SimCLR-Pytorch/experiments/yyyy-mm-dd_hh-mm-ss/checkpoint.pt

Hyperparameters

The configuration / choice of hyperparameters for the script is handled either by command line arguments or config files.

An example config file is given at SimCLR-Pytorch/config.conf. Additionally, .txt or .conf files can be passed if you prefer; this is achieved with the flag -c <PATH/TO/CONFIG.conf>.
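A minimal sketch of such a config file, assuming the long-option names listed below map directly to keys (values taken from the ResNet18 row of the results table; syntax per the key=value rules quoted in the help text):

```
# hypothetical config.conf
dataset = cifar10
model = resnet18
n_epochs = 1000
batch_size = 512
optimiser = lars
finetune_optimiser = lars
finetune_learning_rate = 0.1
finetune_epochs = 100
```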

A list of arguments/options can be found below:

usage: main.py [-h] [-c MY_CONFIG] [--dataset DATASET]
               [--dataset_path DATASET_PATH] [--model MODEL]
               [--n_epochs N_EPOCHS] [--finetune_epochs FINETUNE_EPOCHS]
               [--warmup_epochs WARMUP_EPOCHS] [--batch_size BATCH_SIZE]
               [--learning_rate LEARNING_RATE]
               [--finetune_learning_rate FINETUNE_LEARNING_RATE]
               [--weight_decay WEIGHT_DECAY]
               [--finetune_weight_decay FINETUNE_WEIGHT_DECAY]
               [--optimiser OPTIMISER]
               [--finetune_optimiser FINETUNE_OPTIMISER] [--patience PATIENCE]
               [--temperature TEMPERATURE] [--jitter_d JITTER_D]
               [--jitter_p JITTER_P] [--blur_sigma BLUR_SIGMA BLUR_SIGMA]
               [--blur_p BLUR_P] [--grey_p GREY_P] [--no_twocrop]
               [--load_checkpoint_dir LOAD_CHECKPOINT_DIR] [--no_distributed]
               [--finetune] [--supervised]

Pytorch SimCLR Args that start with '--' (eg. --dataset) can also be set in a
config file (<PATH>/SimCLR-Pytorch/config.conf or specified via -c).
Config file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details,
see syntax at https://goo.gl/R74nmi). If an arg is specified in more than one
place, then commandline values override config file values which override
defaults.

optional arguments:
  -h, --help            show this help message and exit
  -c MY_CONFIG, --my-config MY_CONFIG
                        config file path
  --dataset DATASET     Dataset, (Options: cifar10, cifar100, stl10, imagenet,
                        tinyimagenet).
  --dataset_path DATASET_PATH
                        Path to dataset, Not needed for TorchVision Datasets.
  --model MODEL         Model, (Options: resnet18, resnet34, resnet50,
                        resnet101, resnet152).
  --n_epochs N_EPOCHS   Number of Epochs in Contrastive Training.
  --finetune_epochs FINETUNE_EPOCHS
                        Number of Epochs in Linear Classification Training.
  --warmup_epochs WARMUP_EPOCHS
                        Number of Warmup Epochs During Contrastive Training.
  --batch_size BATCH_SIZE
                        Number of Samples Per Batch.
  --learning_rate LEARNING_RATE
                        Starting Learning Rate for Contrastive Training.
  --finetune_learning_rate FINETUNE_LEARNING_RATE
                        Starting Learning Rate for Linear Classification
                        Training.
  --weight_decay WEIGHT_DECAY
                        Contrastive Learning Weight Decay Regularisation
                        Factor.
  --finetune_weight_decay FINETUNE_WEIGHT_DECAY
                        Linear Classification Training Weight Decay
                        Regularisation Factor.
  --optimiser OPTIMISER
                        Optimiser, (Options: sgd, adam, lars).
  --finetune_optimiser FINETUNE_OPTIMISER
                        Finetune Optimiser, (Options: sgd, adam, lars).
  --patience PATIENCE   Number of Epochs to Wait for Improvement.
  --temperature TEMPERATURE
                        NT-Xent Temperature Factor (see the loss sketch after
                        this list).
  --jitter_d JITTER_D   Distortion Factor for the Random Colour Jitter
                        Augmentation
  --jitter_p JITTER_P   Probability to Apply Random Colour Jitter Augmentation
  --blur_sigma BLUR_SIGMA BLUR_SIGMA
                        Sigma Range for the Random Gaussian Blur Augmentation
  --blur_p BLUR_P       Probability to Apply Gaussian Blur Augmentation
  --grey_p GREY_P       Probability to Apply Random Grey Scale
  --no_twocrop          Disable the Two Crop Augmentation Used to Create Two
                        Views of the Input for Contrastive Learning. (Twocrop
                        Default: True)
  --load_checkpoint_dir LOAD_CHECKPOINT_DIR
                        Path to Load Pre-trained Model From.
  --no_distributed      Disable Distributed Training. (Distributed Default:
                        True)
  --finetune            Perform Only Linear Classification Training. (Default:
                        False)
  --supervised          Perform Supervised Pre-Training. (Default: False)
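For reference, the NT-Xent loss whose temperature is set by --temperature can be sketched as follows (a minimal sketch, not the repository's exact implementation):

```python
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.5):
    """Minimal NT-Xent sketch: z1, z2 are [N, D] projections of two views."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # [2N, D], unit norm
    sim = z @ z.t() / temperature                        # scaled cosine similarities
    sim.fill_diagonal_(float("-inf"))                    # exclude self-pairs
    # The positive for sample i is its other view at index (i + N) mod 2N.
    targets = (torch.arange(2 * n, device=z.device) + n) % (2 * n)
    return F.cross_entropy(sim, targets)
```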

Dependencies

Install dependencies with requirements.txt:

pip install -r requirements.txt

torch
torchvision
tensorboard
tqdm
configargparse
