fastswa-semi-sup's Introduction

Improving Consistency-Based Semi-Supervised Learning with Weight Averaging

This repository contains the code for our paper Improving Consistency-Based Semi-Supervised Learning with Weight Averaging, which achieves the best known performance for semi-supervised learning on CIFAR-10 and CIFAR-100. By analyzing the geometry of training objectives that involve consistency regularization, we significantly improve the Mean Teacher model (Tarvainen and Valpola, NIPS 2017) and the Pi Model (Laine and Aila, ICLR 2017). We do so with stochastic weight averaging (SWA) and our proposed variant, fastSWA, which reduces prediction error faster.

The BibTeX entry for the paper is:

@article{athiwaratkun2018improving,
  title={Improving Consistency-Based Semi-Supervised Learning with Weight Averaging},
  author={Athiwaratkun, Ben and Finzi, Marc and Izmailov, Pavel and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:1806.05594},
  year={2018}
}

Preparing Packages and Data

The code runs on Python 3 with PyTorch 0.3. The following packages are also required:

pip install scipy tqdm matplotlib pandas msgpack

Then prepare CIFAR-10 and CIFAR-100 with the following commands:

./data-local/bin/prepare_cifar10.sh
./data-local/bin/prepare_cifar100.sh

Semi-Supervised Learning with fastSWA

We provide training scripts in the experiments directory. To replicate the CIFAR-10 results for the Mean Teacher model with a 13-layer CNN and 4000 labels, run the following:

python experiments/cifar10_mt_cnn_short_n4k.py

Similarly, for CIFAR-100 with 10k labels:

python experiments/cifar100_mt_cnn_short_n10k.py

The results are saved to the directories results/cifar10_mt_cnn_short_n4k and results/cifar100_mt_cnn_short_n10k. To plot accuracy versus epoch, run

python read_log.py --pattern results/cifar10_*n4k --cutoff 84 --interval 2 --upper 92
python read_log.py --pattern results/cifar100_*n10k --cutoff 54 --interval 4 --upper 70

Figure 1. CIFAR-10 Test Accuracy of the Mean Teacher Model with SWA and fastSWA using a 13-layer CNN and 4000 labels.

Figure 2. CIFAR-100 Test Accuracy of the Mean Teacher Model with SWA and fastSWA using a 13-layer CNN and 10000 labels.

We provide scripts for ResNet-26 with Shake-Shake regularization for CIFAR-10 and CIFAR-100, as well as other label settings in the directory experiments.

fastSWA and SWA Implementation

fastSWA can be added to an existing training loop with just a few lines of code. First, we create a replica of the model (which can be set to require no gradients to save memory) and initialize the weight averaging optimization object with it.

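# Replica of the network that holds the averaged weights; it needs no gradients.
fastswa_net = create_model(no_grad=True)
# The averaging object wraps the replica and accumulates weights into it.
fastswa_net_optim = optim_weight_swa.WeightSWA(fastswa_net)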
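To illustrate what a weight averaging object of this kind does, here is a minimal sketch in plain Python. It assumes an equal-weight running mean over weight snapshots, as in standard SWA. The class name RunningWeightAverage and its fields are hypothetical and are not the repository's WeightSWA; real code would operate on model parameters rather than lists of floats.

```python
class RunningWeightAverage:
    """Equal-weight running mean over weight snapshots (illustrative only)."""

    def __init__(self, num_params):
        self.avg = [0.0] * num_params  # averaged weights
        self.num_snapshots = 0         # snapshots folded in so far

    def update(self, weights):
        # Incremental mean: avg <- avg + (w - avg) / n, where n counts this snapshot.
        self.num_snapshots += 1
        n = self.num_snapshots
        self.avg = [a + (w - a) / n for a, w in zip(self.avg, weights)]
        return self.avg


averager = RunningWeightAverage(num_params=2)
averager.update([1.0, 4.0])  # first snapshot: the average equals the snapshot
averager.update([3.0, 0.0])  # second snapshot: the average is the midpoint
print(averager.avg)  # [2.0, 2.0]
```

The incremental form avoids storing every snapshot: only the current average and a counter are kept.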

Then, the fastSWA model can be updated every fastswa_freq epochs. Note that after updating the weights, we need to update Batch Normalization running average variables by passing the training data through the fastSWA model.

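# Average only during the final cycle, every fastswa_freq epochs counted from its start.
if epoch >= (args.epochs - args.cycle_interval) and (epoch - (args.epochs - args.cycle_interval)) % fastswa_freq == 0:
  # Fold the current training weights into the average (`model` is assumed to be the network being trained).
  fastswa_net_optim.update(model)
  # Recompute Batch Normalization running statistics for the averaged weights.
  update_batchnorm(fastswa_net, train_loader, train_loader_len)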
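The update schedule can be made concrete with a small helper. This is a sketch under two assumptions: averaging starts at epoch epochs - cycle_interval, and update offsets are counted from that starting epoch. The function name fastswa_update_epochs and the numbers in the example call are hypothetical, chosen only for illustration.

```python
def fastswa_update_epochs(epochs, cycle_interval, fastswa_freq):
    """List the epochs at which the averaged model would be updated."""
    start = epochs - cycle_interval  # averaging begins in the final cycle
    return [
        epoch
        for epoch in range(epochs)
        if epoch >= start and (epoch - start) % fastswa_freq == 0
    ]


# Hypothetical settings: 180 epochs, a final cycle of 30 epochs, averaging every 3 epochs.
schedule = fastswa_update_epochs(epochs=180, cycle_interval=30, fastswa_freq=3)
print(schedule[0], schedule[-1], len(schedule))  # 150 177 10
```

A smaller fastswa_freq folds more snapshots into the average within the same cycle, which is what distinguishes fastSWA from averaging once per cycle.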

For more details, see main.py.

Note: the code is adapted from https://github.com/CuriousAI/mean-teacher/tree/master/pytorch
