swa-tutorials-pytorch

Stochastic Weight Averaging (SWA) tutorials using PyTorch. Based on the official Stochastic Weight Averaging feature introduced in PyTorch 1.6, this repository implements an image-classification codebase on a custom dataset.

  • author: hoya012
  • last update: 2020.10.23

0. Experimental Setup

0-1. Prepare Library

  • You need to install PyTorch and Captum:
pip install -r requirements.txt

0-2. Download dataset (Kaggle Intel Image Classification)

This dataset contains around 25k images of size 150x150, distributed across 6 categories: {'buildings': 0, 'forest': 1, 'glacier': 2, 'mountain': 3, 'sea': 4, 'street': 5}

  • Create a data folder and move the dataset into it.
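The label indices above follow the alphabetical order of the class names. `torchvision.datasets.ImageFolder` assigns indices the same way (sorted sub-folder names), so if the dataset is laid out with one sub-folder per class (an assumption about the folder layout), the mapping can be reproduced with a minimal pure-Python sketch:

```python
# Sketch: build the class -> index mapping the same way torchvision's
# ImageFolder does, i.e. by enumerating the sorted class folder names.
CLASS_NAMES = ["street", "sea", "mountain", "glacier", "forest", "buildings"]


def class_to_index(class_names):
    """Return {class_name: index}, with indices assigned in sorted order."""
    return {name: idx for idx, name in enumerate(sorted(class_names))}


mapping = class_to_index(CLASS_NAMES)
print(mapping)
# {'buildings': 0, 'forest': 1, 'glacier': 2, 'mountain': 3, 'sea': 4, 'street': 5}
```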

1. Baseline Training

  • ImageNet-pretrained ResNet-18 from torchvision.models
  • Batch size 256 / 120 epochs / initial learning rate 0.0001
  • Training augmentation: Resize((256, 256)), RandomHorizontalFlip()
  • Adam optimizer with cosine learning-rate scheduling and warmup
  • Hardware: a single NVIDIA Pascal GPU (GTX 1080 Ti)
python main.py --checkpoint_name baseline;
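The baseline's learning-rate schedule can be sketched as a function of the epoch. This README does not state the warmup length or shape, so the linear ramp and `warmup_epochs=5` below are assumptions; only the 0.0001 initial rate and the 120 epochs come from the settings above.

```python
import math


def warmup_cosine_lr(epoch, base_lr=1e-4, total_epochs=120, warmup_epochs=5):
    """Learning rate for a 0-indexed epoch.

    Linear warmup up to base_lr, then cosine decay towards 0 over the
    remaining epochs. warmup_epochs=5 is a hypothetical value.
    """
    if epoch < warmup_epochs:
        # ramp linearly: base_lr / warmup_epochs, 2 * base_lr / warmup_epochs, ...
        return base_lr * (epoch + 1) / warmup_epochs
    # fraction of the post-warmup phase already completed, in [0, 1)
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


for e in (0, 4, 5, 60, 119):
    print(e, f"{warmup_cosine_lr(e):.2e}")
```

In PyTorch itself, a schedule like this is usually stepped through an `lr_scheduler` object rather than computed by hand; the function only makes the shape of the curve explicit.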

2. Stochastic Weight Averaging Training

In PyTorch 1.6, Stochastic Weight Averaging is very easy to use, thanks to the torch.optim.swa_utils module (AveragedModel, SWALR and update_bn).
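At its core, SWA keeps an equal-weight running mean of the weights visited late in training. By default, `AveragedModel.update_parameters` copies the weights on the first call and afterwards applies avg <- avg + (w - avg) / (n + 1), where n is the number of snapshots averaged so far. The sketch below reproduces that arithmetic with plain Python lists standing in for parameter tensors; it is a simplification, not the library code.

```python
# Pure-Python sketch of the equal-weight running average used by SWA.
# Each snapshot stands in for one model's flattened weights.

def swa_average(snapshots):
    """Fold weight snapshots into a running mean, one snapshot at a time."""
    avg = None
    n_averaged = 0
    for weights in snapshots:
        if avg is None:
            # First snapshot: the average is simply a copy of the weights.
            avg = list(weights)
        else:
            # avg <- avg + (w - avg) / (n + 1)
            avg = [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, weights)]
        n_averaged += 1
    return avg


print(swa_average([[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]))  # [3.0, 2.0]
```

The incremental form gives the same result as averaging all snapshots at the end, but only one extra copy of the weights has to be kept in memory.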

  • Example from PyTorch's official guide
import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss_fn(model(inputs), targets).backward()
        optimizer.step()
    if epoch > swa_start:
        # SWA phase: add the current weights to the average, step the SWA schedule
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data (test_input is not defined in this excerpt)
swa_model.eval()
preds = swa_model(test_input)

  • My own implementation
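# in main.py
import torch
from torch.optim.swa_utils import SWALR

""" define model and learning rate scheduler for stochastic weight averaging """
swa_model = torch.optim.swa_utils.AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=args.swa_lr)

...

# in learning/trainer.py
for batch_idx, (inputs, labels) in enumerate(data_loader):
    # (forward pass, loss, backward and optimizer step omitted in this excerpt)
    # the regular scheduler is stepped every batch, except during the SWA phase
    if args.decay_type != 'swa' or epoch <= args.swa_start:
        self.scheduler.step()

# after swa_start, average the weights and step the SWA scheduler once per epoch
if epoch > args.swa_start and args.decay_type == 'swa':
    self.swa_model.update_parameters(self.model)
    self.swa_scheduler.step()

...

# in main.py
# recompute BatchNorm statistics for the averaged weights; the model is presumably
# moved to the CPU because update_bn does not move input batches when no device is given
swa_model = swa_model.cpu()
torch.optim.swa_utils.update_bn(train_loader, swa_model)
swa_model = swa_model.cuda()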
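The update_bn call is needed because the averaged weights never ran a forward pass of their own, so the BatchNorm running statistics held by swa_model do not match them. update_bn resets those statistics, temporarily sets each BatchNorm layer's momentum to None (which makes PyTorch use a cumulative average over batches), and makes one pass over the loader. A pure-Python sketch for a single running mean, with the default exponential update (momentum 0.1) for comparison; the per-batch means are made-up numbers:

```python
# Pure-Python sketch of how update_bn recomputes one BatchNorm running mean.
# With momentum=None, batch k contributes with factor 1 / k, so the result
# is the plain average of the per-batch means.

def recompute_running_mean(batch_means):
    running_mean = 0.0          # reset, as update_bn does before its pass
    num_batches_tracked = 0
    for batch_mean in batch_means:
        num_batches_tracked += 1
        factor = 1.0 / num_batches_tracked   # momentum=None behaviour
        running_mean = (1 - factor) * running_mean + factor * batch_mean
    return running_mean


def ema_running_mean(batch_means, momentum=0.1, init=0.0):
    """Default BatchNorm update (momentum=0.1), shown for comparison."""
    running_mean = init
    for batch_mean in batch_means:
        running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    return running_mean


batch_means = [2.0, 4.0, 6.0]   # hypothetical per-batch activation means
print(recompute_running_mean(batch_means))   # close to 4.0, the mean of the batches
print(ema_running_mean(batch_means))         # still pulled towards the reset value 0.0
```

Instead of moving the model to the CPU, update_bn also accepts a device argument, e.g. torch.optim.swa_utils.update_bn(train_loader, swa_model, device=torch.device('cuda')), which moves each input batch to that device; this is an alternative, not what main.py does.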

Run Script (Command Line)

python main.py --checkpoint_name swa --decay_type swa --swa_start 90 --swa_lr 5e-5;
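With --swa_start 90, the branching in learning/trainer.py splits the run into a regular phase and an SWA phase. The sketch below counts the epochs in each phase, assuming 0-indexed epochs and the same 120-epoch budget as the baseline (neither is stated for the SWA run); the function name swa_phases is hypothetical.

```python
# Sketch of the epoch-level branching in learning/trainer.py.
# Assumptions: epochs are 0-indexed and the SWA run also lasts 120 epochs.

def swa_phases(total_epochs=120, swa_start=90, decay_type="swa"):
    phases = []
    for epoch in range(total_epochs):
        if decay_type == "swa" and epoch > swa_start:
            # update_parameters(model) + swa_scheduler.step() after this epoch
            phases.append("swa")
        else:
            # regular scheduler, stepped inside the batch loop
            phases.append("regular")
    return phases


phases = swa_phases()
print(phases.count("regular"), phases.count("swa"))  # 91 29
```

Under the same assumptions, --swa_start 75 and --swa_start 60 (the other two swa_start values in the performance table) average 44 and 59 snapshots respectively.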

3. Performance Table

  • B: Baseline
  • SWA: Stochastic Weight Averaging
    • named SWA_{swa_start}_{swa_lr}, e.g. SWA_90_5e-5 uses --swa_start 90 --swa_lr 5e-5

Algorithm       Test Accuracy (%)
B               94.10
SWA_90_0.05     80.53
SWA_90_1e-4     94.20
SWA_90_5e-4     93.87
SWA_90_1e-5     94.23
SWA_90_5e-5     94.57
SWA_75_5e-5     94.27
SWA_60_5e-5     94.33
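The differences from the baseline follow directly from the table; the dictionary below copies its values. For context, swa_lr=0.05 is 500 times the baseline's initial learning rate of 0.0001, and it is the only setting that falls far below the baseline.

```python
# Test accuracy (%) copied from the performance table above.
RESULTS = {
    "B": 94.10,
    "SWA_90_0.05": 80.53,
    "SWA_90_1e-4": 94.20,
    "SWA_90_5e-4": 93.87,
    "SWA_90_1e-5": 94.23,
    "SWA_90_5e-5": 94.57,
    "SWA_75_5e-5": 94.27,
    "SWA_60_5e-5": 94.33,
}


def deltas_vs_baseline(results, baseline_key="B"):
    """Accuracy difference (percentage points) of every run vs. the baseline."""
    baseline = results[baseline_key]
    return {name: round(acc - baseline, 2)
            for name, acc in results.items() if name != baseline_key}


deltas = deltas_vs_baseline(RESULTS)
for name, delta in sorted(deltas.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:<12} {delta:+.2f}")
```

Five of the seven SWA settings beat the baseline; the best, SWA_90_5e-5, is 0.47 points above it.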

4. Code Reference
