
CoViews: Adaptive Augmentation Using Cooperative Views for Enhanced Contrastive Learning

This repository is the official implementation of CoViews: Adaptive Augmentation Using Cooperative Views for Enhanced Contrastive Learning.

In this paper we propose a framework for learning efficient adaptive data augmentation policies for contrastive learning with minimal computational overhead. Our approach continuously generates new data augmentation policies during training and produces effective positives/negatives without any supervision.

Within this framework, we present two methods:

  • IndepViews, which generates a single augmentation policy shared across all views, and
  • CoViews, which generates dependent augmentation policies for each view. This lets the framework learn dependencies between the transformations applied to each view and ensures that the augmentation strategies applied to different views complement each other, leading to more meaningful and discriminative representations. A minimal sketch contrasting the two sampling schemes follows this list.
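
To make the distinction concrete, here is a purely illustrative sketch of the two sampling schemes. The transformation names and probability tables are hypothetical, and the actual policy networks in this repository are trained with PPO rather than given as fixed weights:

```python
import random

# Illustrative only: hypothetical transformation list and policy probabilities.
TRANSFORMS = ["Equalize", "Solarize", "Brightness", "Rotate", "ShearX"]

def sample_indepviews(policy_probs):
    """IndepViews: a single shared policy is used, and each view's subpolicy is
    sampled from it independently."""
    op1 = random.choices(TRANSFORMS, weights=policy_probs)[0]
    op2 = random.choices(TRANSFORMS, weights=policy_probs)[0]
    return op1, op2  # both views drawn from the same, independent distribution

def sample_coviews(policy_probs, conditional_probs):
    """CoViews: the subpolicy for view 2 is sampled conditioned on the subpolicy
    chosen for view 1, so the two augmentations can complement each other."""
    op1 = random.choices(TRANSFORMS, weights=policy_probs)[0]
    op2 = random.choices(TRANSFORMS, weights=conditional_probs[op1])[0]
    return op1, op2
```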

Through extensive experimentation on multiple datasets and contrastive learning frameworks, we demonstrate that our method consistently outperforms baseline solutions and that training with a view-dependent augmentation policy outperforms training with a shared independent policy, showcasing its effectiveness in enhancing contrastive learning performance.

Motivation:

Data augmentation plays a critical role in generating high-quality positive and negative pairs necessary for effective contrastive learning. However:

  • Common practices reuse a single augmentation policy to generate multiple views, potentially leading to inefficient training pairs due to a lack of cooperation between views.
  • Furthermore, to find the optimal set of augmentations, many existing methods require extensive supervised evaluation, overlooking the evolving nature of the model, which may require different augmentations throughout training.
  • Other approaches train differentiable augmentation generators, thus limiting the use of non-differentiable transformation functions from the literature.

Requirements

To install requirements:

pip install -r requirements.txt

Training

Here are the main command-line arguments for training:

| Argument | Description |
| --- | --- |
| dataset | The training dataset: cifar10, svhn, cifar100, stl10, TinyImagenet |
| augmentation | The augmentation strategy to use: randaugment, random, ppo |
| two_branches | Boolean flag: if set, CoViews is used; otherwise, IndepViews |
| randaugment_M | Magnitude parameter for RandAugment |
| ppo_iterations | Number of PPO iterations used to train the policy network |
| reward_th | Threshold parameter of the Bounded InfoNCE reward function |
| reward_b | Tolerance parameter of the Bounded InfoNCE reward function |
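
For intuition about the last two arguments, below is a rough, non-authoritative sketch of how a bounded reward could be derived from the per-pair InfoNCE loss. The exact Bounded InfoNCE formulation used in the paper may differ; the function name, the temperature, and the shape of the bounding are assumptions:

```python
import torch
import torch.nn.functional as F

def bounded_infonce_reward(z1, z2, threshold, tolerance, temperature=0.5):
    """Sketch only (not the paper's exact formulation): reward an augmented pair
    by its per-sample InfoNCE loss, but stop rewarding pairs that become too hard,
    i.e. whose loss exceeds the threshold, decaying the reward to zero within the
    tolerance band."""
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature                        # pairwise similarities
    labels = torch.arange(z1.size(0), device=z1.device)       # positives on the diagonal
    loss = F.cross_entropy(logits, labels, reduction="none")  # per-sample InfoNCE loss
    # Harder pairs (larger loss) are rewarded up to the threshold; beyond it the
    # reward decays linearly and reaches zero after `tolerance`.
    reward = torch.where(
        loss <= threshold,
        loss,
        threshold * (1.0 - (loss - threshold) / tolerance).clamp(min=0.0),
    )
    return reward
```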

To train using CoViews, set the --two_branches flag by running this command:

python train.py --dataset cifar10 --augmentation ppo --two_branches

To train using IndepViews, run this command:

python train.py --dataset cifar10 --augmentation ppo

To train using random augmentations, run this command:

python train.py --dataset cifar10 --augmentation random

To train using RandAugment, run this command:

python train.py --dataset cifar10 --augmentation randaugment --randaugment_M 15
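
For reference, the randaugment baseline corresponds conceptually to a standard RandAugment view pipeline. A comparable (but not identical) setup using torchvision would look like the sketch below; the crop size assumes CIFAR-10-sized images, and the repository ships its own implementation selected via --augmentation randaugment:

```python
from torchvision import transforms

# Illustrative torchvision-based equivalent of a RandAugment view pipeline.
randaug_view = transforms.Compose([
    transforms.RandomResizedCrop(32),                  # CIFAR-10-sized crops (assumption)
    transforms.RandAugment(num_ops=2, magnitude=15),   # roughly matches --randaugment_M 15
    transforms.ToTensor(),
])
```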

Evaluation

To perform linear evaluation of the CoViews model on cifar10, run:

python eval.py --dataset cifar10 --path ./models/cifar10_CoViews.pt
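
Linear evaluation freezes the pretrained encoder and trains only a linear classifier on top of its representations. The following is a minimal sketch of that protocol; the checkpoint format, the feature dimension (512), and the class count (10 for CIFAR-10) are assumptions, and eval.py may load and evaluate the model differently:

```python
import torch
import torch.nn as nn

# Load the pretrained encoder and freeze it (the checkpoint may instead hold a state dict).
encoder = torch.load("./models/cifar10_CoViews.pt", map_location="cpu")
encoder.eval()
for p in encoder.parameters():
    p.requires_grad_(False)

linear_head = nn.Linear(512, 10)     # feature dim and class count are assumptions
optimizer = torch.optim.Adam(linear_head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

def linear_probe_step(images, labels):
    with torch.no_grad():
        features = encoder(images)   # frozen representations
    loss = criterion(linear_head(features), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```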

Results

Linear probe accuracy:

The following table shows the top-1 linear probe accuracy on CIFAR-10, CIFAR-100, SVHN, STL10, and TinyImagenet. IndepViews and CoViews consistently outperform the baseline solutions. Standard deviations are computed over 5 different random initializations of the linear head; they are small because the linear probe is robust to the random seed.

| | CIFAR-10 | CIFAR-100 | SVHN | STL10 | TinyImagenet |
| --- | --- | --- | --- | --- | --- |
| RandAug M=9 | 92.61 $\pm$ 0.02 | 68.54 $\pm$ 0.02 | 94.27 $\pm$ 0.00 | 89.98 $\pm$ 0.08 | 30.41 $\pm$ 0.07 |
| RandAug M=15 | 93.16 $\pm$ 0.03 | 70.46 $\pm$ 0.04 | 95.48 $\pm$ 0.02 | 91.05 $\pm$ 0.11 | 31.20 $\pm$ 0.06 |
| RandAug M=27 | 92.55 $\pm$ 0.07 | 69.11 $\pm$ 0.04 | 94.01 $\pm$ 0.01 | 90.88 $\pm$ 0.07 | 30.43 $\pm$ 0.03 |
| Random | 92.92 $\pm$ 0.04 [model] | 69.56 $\pm$ 0.02 [model] | 96.52 $\pm$ 0.01 [model] | 91.90 $\pm$ 0.03 [model] | 30.23 $\pm$ 0.06 [model] |
| IndepViews (Ours) | 93.68 $\pm$ 0.04 [model] | 72.14 $\pm$ 0.07 [model] | 96.58 $\pm$ 0.03 [model] | 93.01 $\pm$ 0.08 [model] | 35.07 $\pm$ 0.08 [model] |
| CoViews (Ours) | 93.79 $\pm$ 0.08 [model] | 72.28 $\pm$ 0.13 [model] | 96.69 $\pm$ 0.07 [model] | 93.67 $\pm$ 0.05 [model] | 36.29 $\pm$ 0.08 [model] |

Inspecting the learned subpolicies (IndepViews vs. CoViews):

Evolution of the probability of transformations:

The following figure compares the evolution of transformation probabilities in the learned adaptive augmentation policies of IndepViews and CoViews on the CIFAR-10 dataset. A sketch of how these curves can be logged follows the list below.


  • IndepViews:
    • Highly variable distribution with frequent spikes.
    • Especially noticeable for transformations like Equalize, Solarize, and Brightness.
    • This variability may indicate instability in the learning process, even though some transformations see increased probabilities over training.
  • CoViews:
    • More consistent distribution with minor fluctuations.
    • Avoids pronounced spikes observed in IndepViews.
    • Certain transformations still receive slightly higher probabilities.
    • More equitable distribution across all transformations.
    • Potentially better suited for learning adaptive augmentation policies.
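
A sketch of how such curves can be produced: after each policy update, record the probability the policy network assigns to each transformation and plot the trajectories. The `policy.transform_probs()` accessor and the transformation subset below are hypothetical:

```python
import numpy as np
import matplotlib.pyplot as plt

TRANSFORM_NAMES = ["Equalize", "Solarize", "Brightness"]  # subset, for illustration
history = []  # (step, probabilities) pairs collected during training

def log_policy(policy, step):
    # Hypothetical accessor returning an array of per-transformation probabilities.
    history.append((step, np.asarray(policy.transform_probs())))

def plot_history():
    steps = [s for s, _ in history]
    probs = np.stack([p for _, p in history])   # shape: (num_steps, num_transforms)
    for i, name in enumerate(TRANSFORM_NAMES):
        plt.plot(steps, probs[:, i], label=name)
    plt.xlabel("training step")
    plt.ylabel("selection probability")
    plt.legend()
    plt.show()
```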

Co-occurrence matrix of the transformations:

The following figure compares the co-occurrence matrices of the transformations applied to view 1 and view 2 for both IndepViews and CoViews. A sketch of how such a matrix can be computed follows the list below.


  • IndepViews:
    • Prioritizes transformations like Equalize, Solarize, and Brightness.
    • Often applies the same transformation to both views.
    • This keeps pairs challenging even though each view is generated without any information about the other view.
  • CoViews:
    • Shows a balanced and diverse utilization of transformations.
    • Enables various subpolicy combinations.
    • Suggests that heavily relying on highly distorting transformations may not be necessary.
    • Cooperative view generation can effectively create challenging pairs.
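
A co-occurrence matrix like the one in the figure can be built by counting, over the subpolicies sampled during training, how often transformation i is applied to view 1 while transformation j is applied to view 2. The (op1, op2) pair format of `sampled_pairs` below is an assumption:

```python
import numpy as np

def cooccurrence_matrix(sampled_pairs, transform_names):
    """Count (view-1 transformation, view-2 transformation) pairs."""
    index = {name: i for i, name in enumerate(transform_names)}
    matrix = np.zeros((len(transform_names), len(transform_names)), dtype=np.int64)
    for op1, op2 in sampled_pairs:
        matrix[index[op1], index[op2]] += 1
    return matrix
```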
