
pytorch-playground's Introduction

This is a playground for PyTorch beginners. It contains predefined models for popular datasets. Currently supported:

  • mnist, svhn
  • cifar10, cifar100
  • stl10
  • alexnet
  • vgg16, vgg16_bn, vgg19, vgg19_bn
  • resnet18, resnet34, resnet50, resnet101, resnet152
  • squeezenet_v0, squeezenet_v1
  • inception_v3

Here is an example for the MNIST dataset. It downloads the dataset and the pre-trained model automatically.

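import torch
from torch.autograd import Variable
from utee import selector

# select() returns the model, a dataset-fetcher function, and a flag telling whether the dataset is ImageNet
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
    # wrap the batch as a float tensor and move it to the GPU (requires CUDA)
    data = Variable(torch.FloatTensor(data)).cuda()
    output = model_raw(data)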

To train the MLP model on MNIST yourself, run python mnist/train.py

Install

  • PyTorch (>=0.1.11) and torchvision from the official website; for example, for CUDA 8.0 and Python 3.5:
    • pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl
    • pip install torchvision
  • tqdm
    • pip install tqdm
  • OpenCV
    • conda install -c menpo opencv3
  • Setting PYTHONPATH
    • export PYTHONPATH=/path/to/pytorch-playground:$PYTHONPATH

ImageNet dataset

We provide a precomputed ImageNet validation set of 224x224x3 images. The shorter side of each image is first resized to 256, and a 224x224 patch is then cropped from the center. The cropped images are encoded as JPEG strings and dumped to a pickle file.

  • cd script
  • Download the val224_compressed.pkl
    • axel http://ml.cs.tsinghua.edu.cn/~chenxi/dataset/val224_compressed.pkl
  • python convert.py
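The geometry of the preprocessing described above can be sketched in plain Python. This is a minimal sketch, not the code in convert.py: the helpers resize_shorter_side, center_crop_box and preprocess_geometry are hypothetical, rounding to the nearest pixel is an assumption, and only sizes and crop boxes are computed. The pixel work itself would be done by an image library; with OpenCV, for instance, cv2.resize and cv2.imencode('.jpg', ...) could handle resizing and JPEG encoding.

```python
def resize_shorter_side(height, width, target=256):
    """Scale (height, width) so the shorter side equals `target`, keeping aspect ratio.

    Rounding to the nearest pixel is an assumption of this sketch.
    """
    scale = target / min(height, width)
    return round(height * scale), round(width * scale)


def center_crop_box(height, width, size=224):
    """Return (top, left, bottom, right) of a size x size crop centered in the image."""
    top = (height - size) // 2
    left = (width - size) // 2
    return top, left, top + size, left + size


def preprocess_geometry(height, width, short_side=256, crop=224):
    """Resize the shorter side to `short_side`, then center-crop `crop` x `crop`."""
    new_h, new_w = resize_shorter_side(height, width, short_side)
    return (new_h, new_w), center_crop_box(new_h, new_w, crop)


if __name__ == "__main__":
    # A 375x500 (height x width) landscape image.
    print(preprocess_geometry(375, 500))
```

For the 375x500 image, the shorter side (375) is scaled to 256, the longer side becomes 341, and the 224x224 crop starts 16 rows down and 58 columns in.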

Quantization

We also provide a simple demo that quantizes these models to a specified bit-width with several methods: linear, minmax, and non-linear (log, tanh).

python quantize.py --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1
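To make the linear and minmax methods concrete, here is a minimal pure-Python sketch operating on plain lists. It is an illustration under stated assumptions, not the implementation in quantize.py. The linear quantizer is assumed to be signed fixed-point: it picks enough integer bits to cover the largest magnitude (or, with a nonzero overflow rate, lets that fraction of the largest values saturate), spends the remaining bits on the fraction, and rounds to that grid. The minmax quantizer spreads 2**bits evenly spaced levels between the minimum and maximum value. The function names and the exact rounding and overflow handling are hypothetical.

```python
import math


def integral_bits(values, overflow_rate=0.0):
    """Integer bits needed so that only about `overflow_rate` of |values| saturate.

    Assumes `values` is non-empty; the sign bit is accounted for separately.
    """
    abs_sorted = sorted((abs(v) for v in values), reverse=True)
    idx = min(int(overflow_rate * len(abs_sorted)), len(abs_sorted) - 1)
    # The small epsilon keeps exact powers of two (e.g. 1.0) from overflowing.
    return math.ceil(math.log2(abs_sorted[idx] + 1e-12))


def linear_quantize(values, bits, overflow_rate=0.0):
    """Round each value onto a signed fixed-point grid with `bits` total bits."""
    frac_bits = bits - 1 - integral_bits(values, overflow_rate)
    delta = 2.0 ** -frac_bits  # quantization step
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1  # signed integer range
    out = []
    for v in values:
        q = math.floor(v / delta + 0.5)  # round half up
        out.append(min(max(q, lo), hi) * delta)  # saturate, then rescale
    return out


def min_max_quantize(values, bits):
    """Map [min, max] onto 2**bits evenly spaced levels and back."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return list(values)  # all values equal: nothing to spread across levels
    n = 2 ** bits - 1  # number of steps between the lowest and highest level
    return [math.floor((v - lo) / (hi - lo) * n + 0.5) / n * (hi - lo) + lo
            for v in values]


if __name__ == "__main__":
    weights = [0.5, -0.25, 0.1, 0.9]
    print(linear_quantize(weights, bits=4))
    print(min_max_quantize([0.0, 0.3, 1.0], bits=2))
```

With 4 bits and a largest magnitude below 1, all three non-sign bits go to the fraction, so values snap to multiples of 0.125 (0.1 becomes 0.125, 0.9 becomes 0.875). Raising overflow_rate trades saturation of the largest values for finer resolution on the rest.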

Top-1 Accuracy

We evaluate the top-1 accuracy of popular datasets and models under the linear quantization method. For all results, the running mean and running variance in BN layers are quantized to 10 bits.

Model         32-bit float  12-bit  10-bit  8-bit   6-bit
MNIST         98.42%        98.43%  98.44%  98.44%  98.32%
SVHN          96.03%        96.03%  96.04%  96.02%  95.46%
CIFAR10       93.78%        93.79%  93.80%  93.58%  90.86%
CIFAR100      74.27%        74.21%  74.19%  73.70%  66.32%
STL10         77.59%        77.65%  77.70%  77.59%  73.40%
AlexNet       55.70%        55.66%  55.54%  54.17%  18.19%
VGG16         70.44%        70.45%  70.44%  69.99%  53.33%
VGG19         71.36%        71.35%  71.34%  70.88%  56.00%
ResNet18      68.63%        68.62%  68.49%  66.80%  19.14%
ResNet34      72.50%        72.46%  72.45%  71.47%  32.25%
ResNet50      74.98%        74.94%  74.91%  72.54%  2.43%
ResNet101     76.69%        76.66%  76.22%  65.69%  1.41%
ResNet152     77.55%        77.51%  77.40%  74.95%  9.29%
SqueezeNetV0  56.73%        56.75%  56.70%  53.93%  14.21%
SqueezeNetV1  56.52%        56.52%  56.24%  54.56%  17.10%
InceptionV3   76.41%        76.43%  76.44%  73.67%  1.50%

Note: the 32-bit float ImageNet models are taken directly from torchvision.
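Top-1 accuracy, the metric reported in the table, counts a sample as correct when the class with the highest score equals the label. The following is a minimal pure-Python sketch of that definition, with a hypothetical top1_accuracy helper working on lists of per-class scores; in PyTorch the argmax step is typically done on the output tensor instead (for example output.max(1)[1]). Multiply the result by 100 to get the percentages shown above.

```python
def top1_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class equals the label."""
    if not targets:
        raise ValueError("need at least one sample")
    correct = 0
    for scores, label in zip(logits, targets):
        pred = max(range(len(scores)), key=scores.__getitem__)  # argmax over classes
        correct += int(pred == label)
    return correct / len(targets)


if __name__ == "__main__":
    logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9], [0.3, 0.2, 0.1]]
    print(top1_accuracy(logits, [1, 0, 1, 0]))
```

Here three of the four predictions (classes 1, 0 and 0) match their labels, so the accuracy is 0.75.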

Selected Arguments

Here is an overview of selected arguments of quantize.py:

Flag           Default value  Description and options
type           cifar10        mnist, svhn, cifar10, cifar100, stl10, alexnet, vgg16, vgg16_bn, vgg19, vgg19_bn, resnet18, resnet34, resnet50, resnet101, resnet152, squeezenet_v0, squeezenet_v1, inception_v3
quant_method   linear         quantization method: linear, minmax, log, tanh
param_bits     8              bit-width of weights and biases
fwd_bits       8              bit-width of activations
bn_bits        32             bit-width of BN running mean and running variance
overflow_rate  0.0            overflow rate threshold for the linear quantization method
n_samples      20             number of samples used to collect activation statistics
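The flags in the table above map naturally onto an argparse parser. The sketch below is a hypothetical reconstruction built only from that table: build_parser is an invented name, the real quantize.py may define more flags (the example command also passes --ngpu, whose default is not documented here and is therefore omitted), and the types are assumptions inferred from the default values.

```python
import argparse

MODEL_TYPES = [
    "mnist", "svhn", "cifar10", "cifar100", "stl10", "alexnet",
    "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "squeezenet_v0", "squeezenet_v1", "inception_v3",
]


def build_parser():
    # Hypothetical parser mirroring the documented flags and defaults only.
    p = argparse.ArgumentParser(description="Quantize a pretrained model (sketch)")
    p.add_argument("--type", default="cifar10", choices=MODEL_TYPES)
    p.add_argument("--quant_method", default="linear",
                   choices=["linear", "minmax", "log", "tanh"])
    p.add_argument("--param_bits", type=int, default=8, help="bit-width of weights and biases")
    p.add_argument("--fwd_bits", type=int, default=8, help="bit-width of activations")
    p.add_argument("--bn_bits", type=int, default=32, help="bit-width of BN running stats")
    p.add_argument("--overflow_rate", type=float, default=0.0,
                   help="overflow rate threshold for linear quantization")
    p.add_argument("--n_samples", type=int, default=20,
                   help="number of samples used to collect activation statistics")
    return p


if __name__ == "__main__":
    print(build_parser().parse_args(["--type", "resnet18", "--bn_bits", "10"]))
```

Restricting --type and --quant_method with choices makes a misspelled model or method name fail at parse time rather than deep inside the quantization code.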
