
DASO: Distribution-Aware Semantics-Oriented Pseudo-label for Imbalanced Semi-Supervised Learning.

This repo provides the official PyTorch implementation of our paper:

DASO: Distribution-Aware Semantics-Oriented Pseudo-label for Imbalanced Semi-Supervised Learning

Youngtaek Oh, Dong-Jin Kim, and In So Kweon

CVPR 2022

[Paper] [Project Page]

This repo includes the experiments on the CIFAR-10, CIFAR-100, and STL-10 datasets with varying numbers of data points and imbalance ratios in semi-supervised learning.

Abstract

The capability of traditional semi-supervised learning (SSL) methods falls far short of real-world applications due to severely biased pseudo-labels caused by (1) class imbalance and (2) class distribution mismatch between labeled and unlabeled data.

This paper addresses this relatively under-explored problem. First, we propose a general pseudo-labeling framework that class-adaptively blends the semantic pseudo-label from a similarity-based classifier with the one from the linear classifier, after observing that the two types of pseudo-labels have complementary properties in terms of bias. We further introduce a novel semantic alignment loss that establishes a balanced feature representation to reduce biased predictions from the classifier. We term the whole framework Distribution-Aware Semantics-Oriented (DASO) Pseudo-label.

We conduct extensive experiments on a wide range of imbalanced benchmarks: CIFAR10/100-LT, STL10-LT, and the large-scale, long-tailed Semi-Aves benchmark with open-set classes, and demonstrate that the proposed DASO framework reliably improves SSL learners with unlabeled data, especially when both (1) class imbalance and (2) distribution mismatch dominate.

Figure 2: Analysis on recall and precision of pseudo-labels and the corresponding test accuracy obtained from FixMatch, USADTM, and DASO (Ours).
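
The core blending idea from the abstract can be sketched in a few lines of schematic PyTorch. This is an illustration only, not the exact formulation from the paper or this repo; in particular, how the class-wise weights v are computed from the estimated class distribution is defined in the paper and implemented under lib/algorithm/.

# Schematic of class-adaptive pseudo-label blending (illustration only; not the repo's code).
import torch

def blend_pseudo_labels(p_linear, p_semantic, v):
    # p_linear:   (B, C) softmax output of the linear classifier
    # p_semantic: (B, C) similarity-based pseudo-label from class prototypes
    # v:          (C,)   class-wise blending weights in [0, 1]
    k = p_linear.argmax(dim=1)      # class predicted by the linear head, shape (B,)
    w = v[k].unsqueeze(1)           # per-sample blending weight, shape (B, 1)
    return (1.0 - w) * p_linear + w * p_semantic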

Installation

The code is tested with:

  • CUDA 10.1
  • Python 3.6
  • PyTorch 1.4.0
  • Single Titan Xp GPU (Mem: 12GB)

To install requirements:

(Option 1) Install with conda

# installing the environment via conda
conda create -n daso python=3.6 -y
conda activate daso

# installing required libraries
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.0 -c pytorch  # pytorch for CUDA 10.0

pip install yacs
pip install future tensorboard

git clone https://github.com/ytaek-oh/daso.git && cd daso

(Option 2) Install with Docker

git clone https://github.com/ytaek-oh/daso.git && cd daso
bash docker/run_docker.sh  # see the Dockerfile therein if needed.
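
With either option, you can optionally verify from Python that PyTorch and the GPU are visible before launching training (a quick sanity check, not part of the repo):

# optional sanity check for the installed environment
import torch

print(torch.__version__)               # expect 1.4.0
print(torch.cuda.is_available())       # expect True on a machine with CUDA 10.x
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # e.g., "TITAN Xp"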

Code Structure

  • configs/: includes all the pre-defined configuration files for each baseline method and our DASO across all the benchmarks.
  • lib/algorithm/: includes the implementations of each baseline method.
  • lib/config/: defines all the configuration options that can be overridden.

Training

  • Note that each dataset is automatically downloaded to ./data/ the first time you run training.

  • You may refer to the pre-defined configuration files in ./configs/; all the available default parameters are defined in lib/config/defaults.py.

To train the model(s) in the paper, run the command as:

python main.py --config-file {PATH_TO_CONFIG_FILE} {OTHER OPTIONS}

Example (1):

# run FixMatch + DASO on CIFAR10-LT with N_1=500, M_1=4000 and γ_l = γ_u = 100.
# SEED and GPU_ID are optional overrides.
python main.py --config-file configs/cifar10/fixmatch_daso.yaml \
                DATASET.CIFAR10.NUM_LABELED_HEAD 500 DATASET.CIFAR10.NUM_UNLABELED_HEAD 4000 \
                SEED {SEED} GPU_ID {DEVICE} \
                {OTHER OPTIONS}

Example (2):

# run FixMatch + DASO on CIFAR100-LT with N_1=150 and γ_l = γ_u = 20.
python main.py --config-file configs/cifar100/fixmatch_daso.yaml \
                DATASET.CIFAR100.IMB_FACTOR_L 20 DATASET.CIFAR100.IMB_FACTOR_UL 20 \
                {OTHER OPTIONS}

Example (3):

# run FixMatch + DASO on STL10-LT with N_1=150 and γ_l = 20.
python main.py --config-file configs/stl10/fixmatch_daso.yaml \
                DATASET.STL10.NUM_LABELED_HEAD 150 DATASET.STL10.IMB_FACTOR_L 20 \
                {OTHER OPTIONS}

A commonly used template for the commands would be:

python main.py --config-file configs/{DATASET}/{ALGORITHM}.yaml \
        DATASET.{DATASET}.NUM_LABELED_HEAD {VALUE} DATASET.{DATASET}.NUM_UNLABELED_HEAD {VALUE} \
        DATASET.{DATASET}.IMB_FACTOR_L {VALUE} DATASET.{DATASET}.IMB_FACTOR_UL {VALUE} \
        (ALGORITHM.{ALGORITHM}.{OPTIONS} {VALUE} ...)
  • For detailed dataset-specific setups and hyper-parameters, please refer to the Implementation Details section in the Appendix of our paper.
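
The trailing KEY VALUE pairs in the commands above are yacs-style overrides: main.py starts from the defaults in lib/config/defaults.py, merges the given config file, and then merges the command-line options. Below is a minimal sketch of that pattern; the node names and values are illustrative, not the repo's actual defaults, so see lib/config/ and main.py for how the config is really built.

# Sketch of merging defaults, a config file, and command-line overrides with yacs
# (node names are illustrative; the real defaults define many more keys).
from yacs.config import CfgNode as CN

_C = CN()
_C.SEED = 1
_C.GPU_ID = 0
_C.DATASET = CN()
_C.DATASET.CIFAR10 = CN()
_C.DATASET.CIFAR10.NUM_LABELED_HEAD = 1500
_C.DATASET.CIFAR10.NUM_UNLABELED_HEAD = 3000

cfg = _C.clone()
cfg.merge_from_file("configs/cifar10/fixmatch_daso.yaml")        # values from the config file
cfg.merge_from_list(["DATASET.CIFAR10.NUM_LABELED_HEAD", 500])   # command-line style overrides
cfg.freeze()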

Monitoring Performance

  • The final configuration file, checkpoints, and all the logged results and TensorBoard metrics are saved to the {OUTPUT_DIR} path designated in the config file.

Citation

If you find our work useful for your research, please cite our paper using the following BibTeX entry:

@inproceedings{oh2021distribution,
  title = {DASO: Distribution-aware semantics-oriented pseudo-label for imbalanced semi-supervised learning},
  author = {Oh, Youngtaek and Kim, Dong-Jin and Kweon, In So},
  booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition (CVPR)},
  year = {2022},
  pages = {9786--9796}
}


daso's Issues

Request to put code

Hello.

Can you please upload the code for the paper "Distribution-Aware Semantics-Oriented Pseudo-label for Imbalanced Semi-Supervised Learning"?

Exact same wandb plots for STL-10 Crest/Crest+

For the same seed, even after changing the ALGORITHM.CREST.PROGRESSIVE_ALIGN argument to False (for Crest), the plots obtained for both Crest and Crest+ are the same. This happens only for the STL-10 dataset, not for CIFAR-10/100.

About the accuracy

Hi. Thanks for the great work here.

But I can't reproduce the result you reported in Table 2 (Ours, γ_u = 1/100 (reversed), N_1 = 1500, M_1 = 3000). I use the config in configs/cifar10/fixmatch_daso.yaml, and my command is:

python main.py --config-file configs/cifar10/fixmatch_daso.yaml \
    DATASET.CIFAR10.NUM_LABELED_HEAD 1500 DATASET.CIFAR10.NUM_UNLABELED_HEAD 3000 \
    DATASET.REVERSE_UL_DISTRIBUTION True

The result I get in result.json is around 77.9, three points lower than your result.

FixMatch accuracy on CIFAR10-LT

Hi,

Thanks for the great work. I had a concern regarding the performance of FixMatch on CIFAR-10-LT in Table 1. With N1=1500, M1=3000, DARP reported 71.5% and 68.5% test accuracy for r=100 and 150, respectively. But your numbers are 77.5% and 72.4%, respectively. Could you please explain what is causing the performance difference here?

Thanks!
