dl_practice

Practice: Classification

In this practice, you will train a simple neural network classifier and experiment with a variety of training tricks.

Getting started

Before starting this practice, you need to understand how the PyTorch framework works (tensors, gradients, networks, loss functions, and optimizers). Please refer to Deep Learning with PyTorch: A 60 Minute Blitz and the PyTorch examples.
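As a quick refresher, one optimization step touches all five concepts listed above. The following is a minimal sketch on random made-up data (the shapes, learning rate, and variable names are arbitrary choices, not part of the practice):

```python
import torch
import torch.nn as nn

# Tensors: 4 samples with 3 features each, plus integer class labels
x = torch.randn(4, 3)
y = torch.tensor([0, 1, 2, 1])

# Network: a single linear layer mapping 3 features to 3 class scores
net = nn.Linear(3, 3)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

# One training step: forward pass, loss, backward pass (gradients), update
optimizer.zero_grad()          # clear gradients left over from earlier steps
loss = criterion(net(x), y)
loss.backward()                # fills net.weight.grad and net.bias.grad
optimizer.step()               # updates the parameters using those gradients
print(loss.item())
```

Every training loop in the problems below repeats this zero_grad / forward / backward / step pattern over mini-batches.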

Prerequisites

  • Python 3.6.9
  • CUDA 10.2
  • PyTorch 1.4.0
  • NumPy 1.18.3
  • OpenCV 4.1.2

Install PyTorch and torchvision with pip:

```shell
pip3 install torch
pip3 install torchvision
```

Dataset

CIFAR-100

The CIFAR-100 dataset consists of 60000 32x32 colour images in 100 classes, with 600 images per class. There are 50000 training images and 10000 test images.

You can download the data with the torchvision API:

```python
import torch
from torchvision import datasets, transforms

BATCH_SIZE = 64

# Convert images to tensors and normalize each RGB channel
# (these mean/std values are the ImageNet statistics)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR100('./data', train=True, download=True, transform=transform),
    batch_size=BATCH_SIZE, shuffle=True)

# The test set does not need shuffling: evaluation order does not affect accuracy
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR100('./data', train=False, download=True, transform=transform),
    batch_size=BATCH_SIZE, shuffle=False)
```

Problem sets

Let’s train different models to recognize CIFAR-100 classes! Please compare their convergence time and test accuracy.
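Comparing convergence time and test accuracy is easier with one shared training and evaluation loop. The following is a minimal sketch, not part of the original practice: the helper names `train_one_epoch`, `evaluate`, and `fit` are hypothetical, and any iterable of `(images, labels)` batches (such as `train_loader` and `test_loader` above) can be passed in.

```python
import time
import torch
import torch.nn as nn

def train_one_epoch(model, loader, optimizer, device="cpu"):
    """One pass over the training data; returns the mean training loss."""
    model.train()  # enables dropout and batch-norm training behaviour
    criterion = nn.CrossEntropyLoss()
    total_loss, total = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        total += labels.size(0)
    return total_loss / total

@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Returns classification accuracy (between 0 and 1) on the given loader."""
    model.eval()  # disables dropout; batch norm uses its running statistics
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total

def fit(model, train_loader, test_loader, optimizer, epochs=10,
        scheduler=None, device="cpu"):
    """Trains for several epochs, printing loss, test accuracy and elapsed time."""
    model.to(device)
    start, acc = time.time(), 0.0
    for epoch in range(epochs):
        loss = train_one_epoch(model, train_loader, optimizer, device)
        if scheduler is not None:
            scheduler.step()  # e.g. learning rate decay once per epoch
        acc = evaluate(model, test_loader, device)
        print(f"epoch {epoch + 1}: loss {loss:.3f}, test acc {acc:.3f}, "
              f"{time.time() - start:.1f}s elapsed")
    return acc
```

Because `train_one_epoch` calls `model.train()` and `evaluate` calls `model.eval()`, dropout is automatically active only during training, which is what problem 6 asks for.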

  1. Build a softmax classification model with a single linear layer using stochastic gradient descent (SGD).

  2. Build a 1-hidden layer neural network with 1024 ReLU units using SGD. This model should improve your test accuracy.

  3. Try to get better performance by adding more layers and using learning rate decay.

  4. Build a convolutional neural network with two convolutional layers followed by one fully connected layer, and train it with the Adam optimizer. (Conv1 and Conv2 each use 16 filters of size 5x5 at stride 2.)

  5. Now replace the strided convolutions with max pooling operations of stride 2 and kernel size 2.

  6. Apply dropout to the hidden layer of your models. Note that dropout should only be introduced during training, not evaluation.

  7. Load a pre-trained ResNet18 model and fine-tune it on the CIFAR-100 dataset.

  8. Train ResNet18 from scratch and compare the result to problem 7.

[Optional]

  1. Apply batch normalization to your models.

  2. Replace the ReLU units in your models with LeakyReLU or SELU.
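Both optional items can be sketched together by adding batch normalization to the pooling CNN and making the activation a parameter. This is a hypothetical variant (the name `make_cnn_bn` and the layer ordering conv, batch norm, activation, pool are assumptions, and only one common choice):

```python
import torch.nn as nn

def make_cnn_bn(activation=nn.ReLU, num_classes=100):
    """Two 16@5x5 conv layers with batch normalization after each convolution
    and a configurable activation class (e.g. nn.ReLU, nn.LeakyReLU, nn.SELU)."""
    layers, in_ch = [], 3
    for _ in range(2):
        layers += [
            # bias=False: the following BatchNorm2d has its own learnable shift
            nn.Conv2d(in_ch, 16, kernel_size=5, stride=1, padding=2, bias=False),
            nn.BatchNorm2d(16),  # normalizes each channel over the mini-batch
            activation(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        in_ch = 16
    layers += [nn.Flatten(), nn.Linear(16 * 8 * 8, num_classes)]
    return nn.Sequential(*layers)
```

For example, `make_cnn_bn(nn.LeakyReLU)` or `make_cnn_bn(nn.SELU)`. As with dropout, batch norm behaves differently in `model.train()` and `model.eval()` mode, so switch modes before training and before evaluation.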

Saving your model

It is important to save your model regularly, especially when you want to reproduce your results or continue the training procedure. You can easily save the model and its parameters with the save/load functions. Note, however, that when you need to resume training, you should follow this example to save all the required state dictionaries.
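A minimal checkpointing sketch for resuming training, built on `torch.save`, `torch.load`, and the `state_dict()`/`load_state_dict()` methods. The function names and dictionary keys are hypothetical choices; the key point is to save the optimizer (and scheduler, if any) state alongside the model, since momentum buffers and learning rates are needed to continue exactly where you stopped.

```python
import torch

def save_checkpoint(path, model, optimizer, epoch, scheduler=None):
    """Save everything needed to resume training later."""
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    if scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()
    torch.save(state, path)

def load_checkpoint(path, model, optimizer, scheduler=None, device="cpu"):
    """Restore model, optimizer (and scheduler) state; return the next epoch to run."""
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    if scheduler is not None and "scheduler_state_dict" in state:
        scheduler.load_state_dict(state["scheduler_state_dict"])
    return state["epoch"] + 1
```

The model and optimizer passed to `load_checkpoint` must be constructed the same way as the ones that were saved; only their state is restored.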

Now save the model that achieves the best performance among the above variants. Try to reproduce your results using the save/load functions instead of running a new training procedure.
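One way to keep the best model and reproduce its results without retraining is sketched below. The helper names and the default file name `best_model.pth` are hypothetical; the sketch saves only the weights, since evaluation does not need the optimizer state.

```python
import torch

def save_if_best(model, acc, best_acc, path="best_model.pth"):
    """Save the model weights when test accuracy improves; return the new best accuracy."""
    if acc > best_acc:
        torch.save(model.state_dict(), path)
        return acc
    return best_acc

def load_for_eval(model, path="best_model.pth", device="cpu"):
    """Load saved weights into a freshly built model of the same architecture."""
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device)
    model.eval()  # switch off dropout and use batch-norm running statistics
    return model
```

After training, rebuild the same architecture, call `load_for_eval` on it, and evaluate on the test set; the accuracy should match the value recorded when the weights were saved.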

Bonus

If you have time, you can use transfer learning to achieve better performance on semantic segmentation. A detailed description is in Segmentation Practice.

Contributors

azuxmioy, cw1204772, tsaishien-chen
