A Generalized Supervised Contrastive Learning Framework

Official PyTorch implementation of GenSCL (paper: arXiv:2206.00384)

Jaewon Kim, Jooyoung Chang, and Sang Min Park

HSDS@Seoul National University

Our implementation is based on the Supervised Contrastive Learning repository.

Abstract

Building on the recent achievements of contrastive learning in self-supervised representation learning, supervised contrastive learning (SupCon) has successfully extended batch contrastive approaches to the supervised setting and outperformed cross-entropy training with ResNet on various datasets. In this work, we present GenSCL: a generalized supervised contrastive learning framework that seamlessly adapts modern image-based regularizations (such as Mixup-Cutmix) and knowledge distillation (KD) to SupCon through our generalized supervised contrastive loss. The generalized supervised contrastive loss further extends the supervised contrastive loss by measuring the cross-entropy between the similarity of labels and that of latent features. A model can therefore learn to what extent contrastives should be pulled closer to an anchor in the latent space. By explicitly and fully leveraging label information, GenSCL breaks the boundary between conventional positives and negatives, and any kind of pre-trained teacher classifier can be utilized. ResNet-50 trained with GenSCL using Mixup-Cutmix and KD achieves state-of-the-art accuracies of 97.6% and 84.7% on CIFAR10 and CIFAR100 without external data, improving on the results reported for the original SupCon by 1.6 and 8.2 percentage points, respectively. A PyTorch implementation is available at https://t.ly/yuUO.

Overview of the results

[Figure: overview of the results]

Loss Function

Our proposed Generalized Supervised Contrastive Loss in loss.py takes a tuple of features and a tuple of labels as input and returns the loss. If the labels are one-hot encoded, it reduces to the Supervised Contrastive Loss.
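A minimal sketch of the idea, written in NumPy rather than PyTorch for readability and assuming the formulation described in the abstract: for each anchor, the target distribution over the other samples is its normalized label similarity (dot products of label vectors), the predicted distribution is a temperature-scaled softmax over cosine similarities of the embeddings, and the loss is the cross-entropy between the two. The function name `gen_supcon_loss`, the `temperature` default, and the exact normalization are assumptions for illustration, not the code in loss.py.

```python
import numpy as np


def gen_supcon_loss(features, labels, temperature=0.1):
    """Hypothetical sketch of a generalized supervised contrastive loss.

    features: tuple of (N, D) arrays, one per augmented view.
    labels:   tuple of (N, C) arrays of (possibly soft) label vectors, one per view.
    """
    z = np.concatenate(features, axis=0)               # (V*N, D)
    y = np.concatenate(labels, axis=0)                 # (V*N, C)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # L2-normalize embeddings

    n = z.shape[0]
    not_self = ~np.eye(n, dtype=bool)                  # exclude anchor-vs-itself pairs

    # Target distribution: label similarity, normalized over j != i.
    label_sim = (y @ y.T) * not_self
    target = label_sim / label_sim.sum(axis=1, keepdims=True)

    # Predicted distribution: log-softmax of feature similarity over j != i.
    logits = (z @ z.T) / temperature
    logits = np.where(not_self, logits, -np.inf)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

    # Cross-entropy between the two distributions, averaged over anchors.
    log_prob = np.where(not_self, log_prob, 0.0)
    return float(-(target * log_prob).sum(axis=1).mean())
```

With one-hot labels, the target puts equal weight on every same-class sample, which recovers the SupCon objective; with soft labels, partially matching samples are pulled closer in proportion to their label overlap. The sketch assumes every anchor shares some label mass with at least one other sample (true when each image contributes two views); otherwise the target row would divide by zero.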

With the Generalized Supervised Contrastive Loss, Mixup/Cutmix and knowledge distillation can be seamlessly adapted to Supervised Contrastive Learning, since both produce soft label vectors that the loss can take directly.
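Mixup and CutMix turn one-hot labels into soft label vectors, which is what lets them plug into a label-similarity loss. Below is a minimal NumPy sketch of the standard label mixing, assuming one-hot labels `y` of shape (N, C), images `x` of shape (N, C, H, W), a mixing coefficient `lam`, and a batch permutation `perm`. The function names are hypothetical and this is not the repository's implementation (which is enabled with `--mix mixup_cutmix`). In the same spirit, a pretrained teacher's softmax output is also a soft label vector; how genscl.py weights it against the ground-truth labels via `--KD-alpha` is not shown here.

```python
import numpy as np


def mixup(x, y, lam, perm):
    """Mixup: convex combination of inputs and of their label vectors."""
    x_mix = lam * x + (1.0 - lam) * x[perm]
    y_mix = lam * y + (1.0 - lam) * y[perm]
    return x_mix, y_mix


def cutmix(x, y, box, perm):
    """CutMix: paste a box from permuted images; mix labels by area ratio.

    box = (top, left, height, width) in pixels; assumed to lie fully
    inside the image, otherwise the area ratio below would be wrong.
    """
    top, left, h, w = box
    x_mix = x.copy()
    x_mix[:, :, top:top + h, left:left + w] = x[perm][:, :, top:top + h, left:left + w]
    lam = 1.0 - (h * w) / (x.shape[2] * x.shape[3])  # fraction of original image kept
    y_mix = lam * y + (1.0 - lam) * y[perm]
    return x_mix, y_mix
```

Either function's `y_mix` can be passed as the label tuple of a generalized supervised contrastive loss, where two mixed images count as partially positive according to how much label mass they share.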

Running

To apply knowledge distillation, a pretrained teacher model (EfficientNetV2-M) is required; the authors have released the pretrained teacher weights.

  • CIFAR10

    • Pretraining stage:

      python genscl.py \
      --dataset cifar10 \
      --mix mixup_cutmix \
      --KD \
      --KD-alpha 1 \
      --teacher-path ./pretrained_saves/efficientnetv2_rw_m_ema_mixup_cutmix_cifar10_Adam

    • Linear evaluation stage:

      python linear.py \
      --dataset cifar10 \
      --pretrained cifar10_bsz_1024_mixup_cutmix_1.0_KD_1.0_SGD_lr_0.5 \
      --augment-policy no \
      --amp

  • CIFAR100

    • Pretraining stage:

      python genscl.py \
      --dataset cifar100 \
      --mix mixup_cutmix \
      --KD \
      --KD-alpha 1 \
      --teacher-path ./pretrained_saves/efficientnetv2_rw_m_ema_mixup_cutmix_cifar100_Adam

    • Linear evaluation stage:

      python linear.py \
      --dataset cifar100 \
      --pretrained cifar100_bsz_1024_mixup_cutmix_1.0_KD_1.0_SGD_lr_0.5 \
      --augment-policy no \
      --amp

Additional options:

  • --optim-kind: SGD, RMSProp, Adam, AdamW

  • --augment-policy: no, sim, auto, rand

  • --wandb: enable Weights & Biases (wandb) logging for visualization

Updates

  • 23 Jun, 2022: Initial upload

Citation

@article{kim2022generalized,
  title={A Generalized Supervised Contrastive Learning Framework},
  author={Kim, Jaewon and Chang, Jooyoung and Park, Sang Min},
  journal={arXiv preprint arXiv:2206.00384},
  year={2022}
}