
pytorch-byol's Introduction

PyTorch-BYOL


Installation

Clone the repository and run

$ conda env create --name byol --file env.yml
$ conda activate byol
$ python main.py

Config

Before running PyTorch-BYOL, make sure you have set the desired configuration in the config.yaml file.

network:
  name: resnet18 # base encoder. choose one of resnet18 or resnet50
   
  # Specify a folder containing a pre-trained model to fine-tune. If training from scratch, pass None.
  fine_tune_from: 'resnet-18_40-epochs'
   
  # configurations for the projection and prediction heads
  projection_head: 
    mlp_hidden_size: 512 # Original implementation uses 4096
    projection_size: 128 # Original implementation uses 256

data_transforms:
  s: 1
  input_shape: (96,96,3)

trainer:
  batch_size: 64 # Original implementation uses 4096
  m: 0.996 # momentum for the target-network EMA update (see the sketch after this config block)
  checkpoint_interval: 5000
  max_epochs: 40 # Original implementation uses 1000
  num_workers: 4 # number of workers for the data loader

optimizer:
  params:
    lr: 0.03
    momentum: 0.9
    weight_decay: 0.0004
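
The m value above is the momentum used for the exponential-moving-average (EMA) update of the target network. A minimal sketch of that update, assuming illustrative names online_network and target_network (not necessarily the ones used in this repository):

import torch

@torch.no_grad()
def ema_update(online_network, target_network, m=0.996):
    # target <- m * target + (1 - m) * online, applied parameter-wise
    for online_p, target_p in zip(online_network.parameters(),
                                  target_network.parameters()):
        target_p.data.mul_(m).add_(online_p.data, alpha=1 - m)

With m close to 1 the target network changes slowly, which is what makes it a stable regression target for the online network.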

Feature Evaluation

We measure the quality of the learned representations by linear separability.

During training, BYOL learns features using the STL10 train+unlabeled set and evaluates them on the held-out test set; a minimal sketch of the linear-evaluation protocol follows the results table below.

| Linear Classifier          | Feature Extractor | Architecture | Feature dim | Projection Head dim | Epochs | Batch Size | STL10 Top 1 |
|----------------------------|-------------------|--------------|-------------|---------------------|--------|------------|-------------|
| Logistic Regression        | PCA Features      | -            | 256         | -                   | -      | -          | 36.0%       |
| KNN                        | PCA Features      | -            | 256         | -                   | -      | -          | 31.8%       |
| Logistic Regression (Adam) | BYOL (SGD)        | ResNet-18    | 512         | 128                 | 40     | 64         | 70.1%       |
| Logistic Regression (Adam) | BYOL (SGD)        | ResNet-18    | 512         | 128                 | 80     | 64         | 75.2%       |
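
Linear evaluation means the pre-trained encoder is frozen and only a linear classifier is trained on top of its features. A rough sketch of the protocol, assuming a frozen encoder that maps images to 512-d feature vectors (names and hyperparameters here are illustrative, not the repository's evaluation code):

import torch
import torch.nn as nn

def linear_eval(encoder, train_loader, feature_dim=512, num_classes=10, epochs=10):
    encoder.eval()  # backbone stays frozen; no gradients flow into it
    clf = nn.Linear(feature_dim, num_classes)
    opt = torch.optim.Adam(clf.parameters(), lr=3e-4)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for images, labels in train_loader:
            with torch.no_grad():
                feats = encoder(images)  # fixed representations
            loss = loss_fn(clf(feats), labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return clf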

pytorch-byol's People

Contributors

sthalles, tarunn2799


pytorch-byol's Issues

There should be two MLPs behind the online network and one behind the target network

First of all, thank you so much for making this work open source!

In the paper there should be two MLPs behind the online network and one MLP behind the target network. In the code, however, the online network has one MLP and the target network has no MLP. In addition, there are parameter updates between the online network's MLP and the target network. Is there an oversight here, or did I miss something?
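
For reference, the paper's online branch is encoder → projection MLP → prediction MLP, while the target branch is encoder → projection MLP only, with the target weights following the online weights via EMA rather than gradients. A hedged sketch of that structure (module names and sizes are illustrative, not taken from this repository's code):

import torch.nn as nn
from torchvision import models

def mlp(in_dim, hidden=512, out_dim=128):
    return nn.Sequential(nn.Linear(in_dim, hidden),
                         nn.BatchNorm1d(hidden),
                         nn.ReLU(inplace=True),
                         nn.Linear(hidden, out_dim))

def make_branch():
    encoder = models.resnet18()
    encoder.fc = nn.Identity()  # expose the 512-d representation
    return encoder, mlp(512)    # encoder + projection MLP

# Online branch: trained by gradients and carries the extra prediction MLP.
online_encoder, online_projector = make_branch()
online_predictor = mlp(128, 512, 128)

# Target branch: encoder + projection MLP only; its weights are an EMA copy
# of the online branch and receive no gradient updates.
target_encoder, target_projector = make_branch()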

Representation Collapse

First of all, thank you so much for making this work open source!

My question isn't directly related to your implementation but rather a question about the paper -- I didn't know where else to ask, and I figured you'd have a pretty good understanding of the paper. Having said that, feel free to mark it closed if you think it is inappropriate.

I recently read the paper and I don't understand why the network doesn't cheat by simply learning to output 0s, or in their own words, by learning "collapsed representations." Any ideas?

Need advice using custom dataset

Hi, this is a really good application. I would like your advice on using my own dataset. My images are very different from widely known datasets such as ImageNet, which contain common objects (cars, cats, dogs, etc.). They are quite technical: dynamometer cards, commonly used in the oil industry, as shown here:
https://knepublishing.com/index.php/KnE-Engineering/article/download/3083/6588/15176
Unlike common objects, my data always comes in a neutral position: never rotated, never flipped, no color distortion (the images are just black and white), and always showing the full object (unlike cars, which may appear only partially in an image). These are exactly the variations that can occur with common objects in the real world, so for common object recognition I assume such data augmentations are needed to introduce variation to the model.
Do I need the same data augmentations used by default in this repo, or is it okay for BYOL to use no augmentation or minimal augmentation? What do you think is the best solution for my dataset, especially in terms of augmentation?

thank you
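
One hedged suggestion rather than a definitive answer: BYOL still needs two differing views of each image, otherwise there is nothing for the online network to predict, so instead of dropping augmentation entirely it may make sense to keep only the distortions that still produce plausible dynamometer cards (random crops, mild blur) and remove flips, rotations, and color jitter. A minimal sketch of such a reduced pipeline (all parameters are illustrative):

from torchvision import transforms

# Reduced view-generation pipeline for grayscale, orientation-fixed images:
# cropping and blur keep the two views different; flips and color distortion
# are removed because they never occur in the real data.
reduced_transform = transforms.Compose([
    transforms.RandomResizedCrop(96, scale=(0.6, 1.0)),
    transforms.GaussianBlur(kernel_size=9),
    transforms.ToTensor(),
])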

About STL10 top-1 accuracy

Hi, thanks for sharing your work, but I have a question about your STL10 top-1 accuracy. Is the 70.1% accuracy obtained by fine-tuning from your pre-trained model 'resnet-18_40-epochs', or by training from scratch for 40 epochs?

Training on CIFAR10

Hello,

Thank you for this excellent repository!

Do you have any suggestions of changes to make to train BYOL on the CIFAR10 dataset?

This is how I am doing it (in main.py); I am also training my own custom models, but I do not think that is too relevant:

# (inside main.py, where datasets, data_transform and MultiViewDataInjector are already available)
DATASET = 'CIFAR10'  # Can change to STL10

if DATASET == 'STL10':
    train_dataset = datasets.STL10('/workspace/STLDataset', split='train+unlabeled', download=True,
                                   transform=MultiViewDataInjector([data_transform, data_transform]))
elif DATASET == 'CIFAR10':
    train_dataset = datasets.CIFAR10('/workspace/CIFAR10Dataset', train=True, download=True,
                                     transform=MultiViewDataInjector([data_transform, data_transform]))
else:
    print("Error, dataset not supported, choose CIFAR10 or STL10")
    exit(0)

I also change the config to have: input_shape: (32,32,3).

Further, I may not have taken a very deep look into this code base, but how do we produce the 'STL10 Top 1' accuracies (75.2%) after training the model on the self-supervised task? Do we take the trained model and fine-tune it on the supervised STL10 dataset? I assume that code is not included in this library?

Thank you!

Fine-tuning problem

There is only linear evaluation but no fine-tuning, which might lead to a better result.

Training on CIFAR-10

Hello, thank you very much for the program you shared, but I want to test it on the CIFAR-10 dataset. How should the relevant parameters be set?

Loss is a negative value?

When I print the loss value during the training phase, I find that the loss is negative after epoch 40 (about -3.45). Is that right? I'm confused.
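
One possible explanation, hedged because it does not quote this repository's code: the BYOL loss for one view pair is the normalized MSE 2 - 2·cos(q, z), which is non-negative, but an implementation that drops the constant and minimizes only the negative cosine term, summed over the two symmetric views, is bounded below by -4, so values around -3.45 are plausible rather than a bug. A small illustration of the two equivalent forms:

import torch.nn.functional as F

def byol_loss(q, z):
    # Normalized MSE between prediction q and target projection z; always >= 0.
    q, z = F.normalize(q, dim=-1), F.normalize(z, dim=-1)
    return (2 - 2 * (q * z).sum(dim=-1)).mean()

def cosine_only_loss(q, z):
    # Same gradients, but without the constant 2: ranges over [-2, 2] per view,
    # so the sum over both views can approach -4.
    q, z = F.normalize(q, dim=-1), F.normalize(z, dim=-1)
    return (-2 * (q * z).sum(dim=-1)).mean()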

Unable to replicate experiments for 80 epochs

Thanks a ton for making this open source. I have a question about replicating the 75% accuracy for 80 epochs. I just changed the config file to read "max_epochs: 80" and then trained the model, but it only reached roughly 70% accuracy (see below). Is there anything else that needs to be changed to reach the 75% accuracy you report? Thanks!

Testing accuracy: 69.5125
Testing accuracy: 69.5125
Testing accuracy: 69.6875
Testing accuracy: 69.9125
Testing accuracy: 69.7875
Testing accuracy: 69.7
Testing accuracy: 69.625
Testing accuracy: 69.8625
Testing accuracy: 69.65
Testing accuracy: 69.5375
Testing accuracy: 70.025
Testing accuracy: 69.8125
