cpcv2-pytorch

Contrastive Predictive Coding

PyTorch implementation of the following papers:

A. van den Oord, Y. Li, and O. Vinyals. Representation Learning with Contrastive Predictive Coding.

O. J. Hénaff, A. Srinivas, J. De Fauw, A. Razavi, C. Doersch, S. M. A. Eslami, and A. van den Oord. Data-Efficient Image Recognition with Contrastive Predictive Coding.
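Both papers train the encoder with the InfoNCE loss: each predicted latent vector has to identify its true target patch among a set of negative samples. The NumPy sketch below illustrates that objective under simplifying assumptions (plain dot-product scores, and the other targets in the batch acting as negatives). The function name `info_nce_loss` is hypothetical, and this is not the repository's PyTorch implementation.

```python
import numpy as np


def info_nce_loss(predictions: np.ndarray, targets: np.ndarray) -> float:
    """Hypothetical helper: InfoNCE loss for N (prediction, target) pairs.

    predictions: (N, D) predicted latent vectors.
    targets:     (N, D) encoded target patches; targets[i] is the positive
                 for predictions[i], and every other row acts as a negative.
    """
    # Score every prediction against every target: (N, N) matrix.
    logits = predictions @ targets.T
    # Numerically stable log-softmax over the candidate targets in each row.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy where the correct target sits on the diagonal.
    return float(-np.mean(np.diag(log_probs)))


# Uninformative predictions give log(N); near-perfect ones give a loss close to 0.
print(info_nce_loss(np.zeros((4, 8)), np.zeros((4, 8))))  # log(4) ≈ 1.386
print(info_nce_loss(10 * np.eye(4), np.eye(4)))
```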

Dependencies

  • PyTorch (verified with version 1.6.0)
  • tqdm
  • numpy
  • opencv-python (needed only for visualising patch augmentations, not for training CPC)

An environment.yml file is also included.

Usage

There are two training scripts: train_CPC.py for the unsupervised (CPC) stage and train_classifier.py for the supervised stage.

  • Viewing all command-line options
    python train_classifier.py -h
    
    python train_CPC.py -h
    
  • Training a fully supervised model
    python train_classifier.py --fully_supervised --dataset stl10 --encoder resnet18
    
  • Training ResNet-14 on STL10 with CPCv1 - Unsupervised Stage
    python train_CPC.py --dataset stl10 --epochs 300 --crop 64-0 --encoder resnet14 --norm none --grid_size 7 --pred_steps 5 --pred_directions 1
    
  • Training WideResNet-28-2 on CIFAR10 with CPCv2 (and a smaller grid size) - Unsupervised Stage
    python train_CPC.py --dataset cifar10 --epochs 500 --crop 30-2 --encoder wideresnet-28-2 --norm layer --grid_size 5 --pred_steps 3 --pred_directions 4 --patch_aug
    
  • Training WideResNet-28-2 on CIFAR10 with CPCv2 (and a smaller grid size) - Supervised Stage with 10,000 labeled images
    python train_classifier.py --dataset cifar10 --train_size 10000 --epochs 100 --lr 0.1 --crop 30-2 --encoder wideresnet-28-2 --norm layer --grid_size 5 --pred_directions 4 --cpc_patch_aug --patch_aug --model_num 500
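The --crop, --grid_size, --pred_steps and --pred_directions flags control how each image is cut into a grid of patches and which patches are predicted from which context. The sketch below shows one plausible reading of them, assuming square images, patches that overlap by 50% (as in the original CPC paper), and a hypothetical mapping in which --pred_directions 1 means top-to-bottom only (as in CPCv1) and 4 means down, up, right and left. Under these assumptions a 64x64 crop with --grid_size 7 yields 16x16 patches with a stride of 8. The helper names are illustrative only; the training scripts are the authority on the exact behaviour.

```python
import numpy as np

# Hypothetical mapping of prediction directions to (row, col) steps.
DIRECTIONS = [(1, 0), (-1, 0), (0, 1), (0, -1)]  # down, up, right, left


def extract_patch_grid(image: np.ndarray, grid_size: int) -> np.ndarray:
    """Hypothetical helper: split a square (H, W, C) image into a
    grid_size x grid_size grid of patches overlapping by 50%.

    Returns an array of shape (grid_size, grid_size, patch, patch, C).
    """
    height, width, _ = image.shape
    assert height == width, "this sketch assumes square images"
    # With 50% overlap, grid_size patches span (grid_size + 1) half-patches.
    patch = 2 * height // (grid_size + 1)
    stride = patch // 2
    assert (grid_size - 1) * stride + patch == height, "grid does not tile the image exactly"
    rows = []
    for r in range(grid_size):
        row = [image[r * stride:r * stride + patch, c * stride:c * stride + patch]
               for c in range(grid_size)]
        rows.append(np.stack(row))
    return np.stack(rows)


def prediction_targets(grid_size: int, pred_steps: int, pred_directions: int):
    """Hypothetical helper: list (source, target, k) triples where the context
    at `source` predicts the patch k steps away, for k = 1..pred_steps, in each
    of the first `pred_directions` directions."""
    triples = []
    for dr, dc in DIRECTIONS[:pred_directions]:
        for r in range(grid_size):
            for c in range(grid_size):
                for k in range(1, pred_steps + 1):
                    tr, tc = r + k * dr, c + k * dc
                    # Skip targets that fall outside the grid.
                    if 0 <= tr < grid_size and 0 <= tc < grid_size:
                        triples.append(((r, c), (tr, tc), k))
    return triples


image = np.zeros((64, 64, 3))
print(extract_patch_grid(image, 7).shape)  # (7, 7, 16, 16, 3)
print(len(prediction_targets(7, 5, 1)))    # 140
print(len(prediction_targets(5, 3, 4)))    # 180
```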
    

Usage on Euler

Clone this repository on Euler and place the required training datasets in the data folder.

🆕 You can now set the size of the unsupervised dataset used to train CPC: pass --unsupervised_size 10000 to train_CPC.py to use only 10,000 images. Use the --t1 {color, rotate, cutout, crop} and --t2 {color, rotate, cutout, crop} flags to choose two augmentations that are applied one after the other to each patch.
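The NumPy sketch below illustrates the idea behind these options: a subset of the unsupervised images is selected, and two augmentations are then applied in sequence to every patch of the grid. The augmentation bodies (per-channel color scaling, random 90-degree rotation, a half-size cutout, and a 3/4-size crop resized back with nearest-neighbour indexing) and the random subset selection are assumptions chosen for illustration; the repository's own transforms and selection scheme may differ, and all helper names are hypothetical.

```python
import numpy as np


def color(patch, rng):
    # Hypothetical color jitter: scale each channel by a random factor.
    return np.clip(patch * rng.uniform(0.6, 1.4, size=patch.shape[-1]), 0.0, 1.0)


def rotate(patch, rng):
    # Rotate by a random multiple of 90 degrees (square patches stay square).
    return np.rot90(patch, k=int(rng.integers(4)), axes=(0, 1))


def cutout(patch, rng):
    # Zero out a random square whose side is half the patch side.
    p = patch.shape[0]
    s = p // 2
    r, c = rng.integers(0, p - s + 1, size=2)
    out = patch.copy()
    out[r:r + s, c:c + s] = 0.0
    return out


def crop(patch, rng):
    # Take a random 3/4-size crop and resize it back with nearest-neighbour indexing.
    p = patch.shape[0]
    s = 3 * p // 4
    r, c = rng.integers(0, p - s + 1, size=2)
    idx = np.arange(p) * s // p
    return patch[r:r + s, c:c + s][idx][:, idx]


AUGMENTATIONS = {"color": color, "rotate": rotate, "cutout": cutout, "crop": crop}


def augment_patches(patches, t1, t2, rng):
    """Apply augmentation t1 and then t2 independently to every patch of a
    (grid, grid, p, p, C) array."""
    first, second = AUGMENTATIONS[t1], AUGMENTATIONS[t2]
    out = np.empty_like(patches)
    for r in range(patches.shape[0]):
        for c in range(patches.shape[1]):
            out[r, c] = second(first(patches[r, c], rng), rng)
    return out


def subset_indices(dataset_size, unsupervised_size, seed=0):
    # Hypothetical: draw a fixed random subset of image indices without replacement.
    return np.random.default_rng(seed).choice(dataset_size, size=unsupervised_size, replace=False)


rng = np.random.default_rng(0)
patches = rng.random((7, 7, 16, 16, 3))  # e.g. a 7x7 grid of 16x16 RGB patches
augmented = augment_patches(patches, "color", "rotate", rng)
indices = subset_indices(100_000, 10_000)  # e.g. 10,000 of STL10's 100,000 unlabelled images
```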

  • Submitting an interactive job on Euler with a GPU
    srun --gpus=1 --gres=gpumem:12g --ntasks=2 --mem-per-cpu=12G --pty python train_CPC.py --dataset stl10 --unsupervised_size 10000 --epochs 100 --crop 64-0 --encoder resnet14 --norm none --grid_size 7 --pred_steps 5 --pred_directions 1
    
  • Submitting a batch job on Euler with a GPU
    sbatch --time=4:00:00 --gpus=1 --gres=gpumem:12g --ntasks=2 --mem-per-cpu=12G --wrap="python train_CPC.py --dataset stl10 --unsupervised_size 10000 --epochs 100 --crop 64-0 --encoder resnet14 --norm none --grid_size 7 --pred_steps 5 --pred_directions 1"
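When sweeping a parameter such as --unsupervised_size, it can help to generate the sbatch commands programmatically. The following Python sketch rebuilds the batch-job command above for several subset sizes and only prints the commands as a dry run. The helper name `sbatch_command` and its defaults are hypothetical; they simply copy the resource flags from that example.

```python
import shlex


def sbatch_command(train_args, time="4:00:00", gpus=1, gpumem="12g",
                   ntasks=2, mem_per_cpu="12G"):
    """Hypothetical helper: build an sbatch argument list that wraps a
    train_CPC.py run, reusing the resource flags from the example above."""
    wrapped = shlex.join(["python", "train_CPC.py", *train_args])
    return [
        "sbatch",
        f"--time={time}",
        f"--gpus={gpus}",
        f"--gres=gpumem:{gpumem}",
        f"--ntasks={ntasks}",
        f"--mem-per-cpu={mem_per_cpu}",
        f"--wrap={wrapped}",
    ]


base_args = ["--dataset", "stl10", "--epochs", "100", "--crop", "64-0",
             "--encoder", "resnet14", "--norm", "none", "--grid_size", "7",
             "--pred_steps", "5", "--pred_directions", "1"]
for size in (5000, 10000, 20000):
    cmd = sbatch_command(base_args + ["--unsupervised_size", str(size)])
    # Dry run: print the shell-quoted command instead of submitting it.
    print(shlex.join(cmd))
```

To actually submit the jobs, pass each list to subprocess.run(cmd, check=True) instead of printing it.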
    

Contributors

rschwarz15, vanditsharma02, rschwarz2015, abby3017, lraud, sachmatkris
