
PyTorch Image Models, etc

Introduction

For each competition, personal, or freelance project involving images + Convolutional Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and pared down iteration of that code. Hopefully it'll be of use to others.

The work of many others is present here. I've tried to make sure all source material is acknowledged:

Models

I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors.

Use the --model arg to specify the model for the train, validation, and inference scripts. Match the all-lowercase creation fn for the model you'd like.

Features

Several (less common) features that I often utilize in my projects are included. Many of these additions are the reason why I maintain my own set of models instead of using others' via PIP:

  • All models have a common default configuration interface and API for
    • accessing/changing the classifier - get_classifier and reset_classifier
    • doing a forward pass on just the features - forward_features
    • these make it easy to write consistent network wrappers that work with any of the models
  • All models have a consistent pretrained weight loader that adapts the last linear layer if necessary, and adapts from 3-channel to 1-channel input if desired
  • The train script works in several process/GPU modes:
    • NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
    • PyTorch DistributedDataParallel w/ multi-GPU, single process (AMP disabled as it crashes when enabled)
    • PyTorch w/ single GPU single process (AMP optional)
  • A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
  • A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance when doing inference with input images larger than the training size. Idea adapted from the original DPN implementation when I ported it (https://github.com/cypw/DPNs)
  • Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
  • Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
  • An inference script that dumps output to CSV is provided as an example
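The sketch below illustrates two of the features above on a toy network: the common classifier/feature API and the selectable global pool. TinyNet, SelectGlobalPool and feat_mult are hypothetical names made up for this example; only get_classifier, reset_classifier and forward_features come from the list above, and the real models' signatures may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SelectGlobalPool(nn.Module):
    """Global pooling whose type is picked at construction time.

    'avg' and 'max' are the plain variants, 'avgmax' averages the two results,
    and 'catavgmax' concatenates them along the channel dimension.
    """

    def __init__(self, pool_type: str = "avg"):
        super().__init__()
        if pool_type not in ("avg", "max", "avgmax", "catavgmax"):
            raise ValueError(f"Unknown pool_type: {pool_type!r}")
        self.pool_type = pool_type

    def feat_mult(self) -> int:
        # Concatenation doubles the number of pooled features.
        return 2 if self.pool_type == "catavgmax" else 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pool_type == "avg":
            out = F.adaptive_avg_pool2d(x, 1)
        elif self.pool_type == "max":
            out = F.adaptive_max_pool2d(x, 1)
        else:
            avg, mx = F.adaptive_avg_pool2d(x, 1), F.adaptive_max_pool2d(x, 1)
            out = 0.5 * (avg + mx) if self.pool_type == "avgmax" else torch.cat([avg, mx], dim=1)
        return out.flatten(1)  # (N, C) or (N, 2C)


class TinyNet(nn.Module):
    """Toy network exposing get_classifier / reset_classifier / forward_features."""

    def __init__(self, num_classes: int = 1000, in_chans: int = 3, global_pool: str = "avg"):
        super().__init__()
        self.num_features = 64
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, self.num_features, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.num_features),
            nn.ReLU(inplace=True),
        )
        self.reset_classifier(num_classes, global_pool)

    def get_classifier(self) -> nn.Module:
        return self.fc

    def reset_classifier(self, num_classes: int, global_pool: str = "avg") -> None:
        # Replace the pooling and head; num_classes=0 leaves pooled features only.
        self.global_pool = SelectGlobalPool(global_pool)
        in_features = self.num_features * self.global_pool.feat_mult()
        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # Unpooled feature map of shape (N, num_features, H, W).
        return self.stem(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.global_pool(self.forward_features(x)))


model = TinyNet(num_classes=10, global_pool="catavgmax")
x = torch.randn(2, 3, 32, 32)
feats = model.forward_features(x)             # (2, 64, 16, 16)
logits = model(x)                             # (2, 10); the head sees 2 * 64 = 128 features
model.reset_classifier(5, global_pool="avg")  # fresh 5-class head on avg-pooled features
new_logits = model(x)                         # (2, 5)
```

Keeping pooling and the head behind reset_classifier is what lets a wrapper swap the number of classes or the pooling type without knowing anything about the model's body.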
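Since mixup is listed as still in progress, the following is only a sketch of the technique as described in the linked paper (arXiv:1710.09412): each batch is blended with a shuffled copy of itself using a weight lam drawn from Beta(alpha, alpha), and the one-hot targets are blended with the same weight. mixup_batch, soft_target_cross_entropy and alpha=0.2 are assumptions for illustration, not this repository's implementation.

```python
import torch
import torch.nn.functional as F


def mixup_batch(x: torch.Tensor, y: torch.Tensor, num_classes: int, alpha: float = 0.2):
    """Blend a batch with a shuffled copy of itself; return inputs, soft targets, lam."""
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1.0 - lam) * x[perm]
    y_onehot = F.one_hot(y, num_classes).float()
    y_mixed = lam * y_onehot + (1.0 - lam) * y_onehot[perm]
    return x_mixed, y_mixed, lam


def soft_target_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Cross entropy against soft (mixed) targets, averaged over the batch.
    return torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1).mean()


x = torch.randn(4, 3, 8, 8)
y = torch.tensor([0, 1, 2, 3])
x_mixed, y_mixed, lam = mixup_batch(x, y, num_classes=5)
logits = torch.randn(4, 5)  # stand-in for model(x_mixed)
loss = soft_target_cross_entropy(logits, y_mixed)
```

Mixing the targets rather than the losses keeps the training loop unchanged apart from the loss function; the two formulations give the same value for cross entropy.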

Self-trained Weights

I've leveraged the training scripts in this repository to train a few of the models with missing weights to good levels of performance. These numbers are all for 224x224 training and validation image sizing with the usual 87.5% validation crop.
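The 87.5% crop works out as follows: the shorter image side is resized to 224 / 0.875 = 256 pixels and the central 224x224 region is kept. The helpers below sketch that arithmetic in plain Python; eval_resize_size and eval_crop_box are hypothetical names, and the truncating resize and floor-division centring are assumptions that may differ by a pixel from a particular image library.

```python
import math
from typing import Tuple


def eval_resize_size(img_size: int = 224, crop_pct: float = 0.875) -> int:
    # Shorter side target, chosen so the center crop keeps crop_pct of it.
    return int(math.floor(img_size / crop_pct))


def eval_crop_box(width: int, height: int, img_size: int = 224,
                  crop_pct: float = 0.875) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    """Return ((resized_w, resized_h), (left, top, right, bottom)) of the center crop."""
    short_target = eval_resize_size(img_size, crop_pct)
    short, long = min(width, height), max(width, height)
    new_long = int(short_target * long / short)  # keep aspect ratio, truncate
    if width <= height:
        new_w, new_h = short_target, new_long
    else:
        new_w, new_h = new_long, short_target
    left = (new_w - img_size) // 2
    top = (new_h - img_size) // 2
    return (new_w, new_h), (left, top, left + img_size, top + img_size)


# A 500x375 (w x h) image is resized to 341x256, then cropped to 224x224.
size, box = eval_crop_box(500, 375)
```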

@ 224x224

| Model | Prec@1 (Err) | Prec@5 (Err) | Param # (M) | Image Scaling |
|---|---|---|---|---|
| resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25 | bicubic |
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8 | bicubic |
| efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29 | bicubic |
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5 | bicubic |
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | bilinear |
| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22 | bilinear |
| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22 | bilinear |
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.42 | bilinear |
| seresnet18 | 71.742 (28.258) | 90.334 (9.666) | 11.8 | bicubic |
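Prec@1 and Prec@5 are top-1 and top-5 accuracy in percent, and each (Err) value is 100 minus the precision (e.g. 100 - 78.512 = 21.488). A plain-Python sketch of the metric, using a hypothetical helper name and a toy batch with k = 1 and 2:

```python
from typing import Dict, Sequence


def topk_precision(scores: Sequence[Sequence[float]], targets: Sequence[int],
                   ks: Sequence[int] = (1, 5)) -> Dict[int, float]:
    """Percentage of samples whose true class is among the k highest scores."""
    hits = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class ids ordered from highest to lowest score.
        ranked = sorted(range(len(row)), key=row.__getitem__, reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    return {k: 100.0 * hits[k] / len(targets) for k in ks}


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
targets = [1, 2, 2, 1]
prec = topk_precision(scores, targets, ks=(1, 2))  # {1: 50.0, 2: 75.0}
err = {k: 100.0 - p for k, p in prec.items()}      # {1: 50.0, 2: 25.0}
```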

Ported Weights

@ 224x224

| Model | Prec@1 (Err) | Prec@5 (Err) | Param # (M) | Image Scaling | Source |
|---|---|---|---|---|---|
| gluon_senet154 | 81.224 (18.776) | 95.356 (4.644) | 115.09 | bicubic | |
| gluon_resnet152_v1s | 81.012 (18.988) | 95.416 (4.584) | 60.32 | bicubic | |
| gluon_seresnext101_32x4d | 80.902 (19.098) | 95.294 (4.706) | 48.96 | bicubic | |
| gluon_seresnext101_64x4d | 80.890 (19.110) | 95.304 (4.696) | 88.23 | bicubic | |
| gluon_resnext101_64x4d | 80.602 (19.398) | 94.994 (5.006) | 83.46 | bicubic | |
| gluon_resnet152_v1d | 80.470 (19.530) | 95.206 (4.794) | 60.21 | bicubic | |
| gluon_resnet101_v1d | 80.424 (19.576) | 95.020 (4.980) | 44.57 | bicubic | |
| gluon_resnext101_32x4d | 80.334 (19.666) | 94.926 (5.074) | 44.18 | bicubic | |
| gluon_resnet101_v1s | 80.300 (19.700) | 95.150 (4.850) | 44.67 | bicubic | |
| gluon_resnet152_v1c | 79.916 (20.084) | 94.842 (5.158) | 60.21 | bicubic | |
| gluon_seresnext50_32x4d | 79.912 (20.088) | 94.818 (5.182) | 27.56 | bicubic | |
| gluon_resnet152_v1b | 79.692 (20.308) | 94.738 (5.262) | 60.19 | bicubic | |
| gluon_resnet101_v1c | 79.544 (20.456) | 94.586 (5.414) | 44.57 | bicubic | |
| gluon_resnext50_32x4d | 79.356 (20.644) | 94.424 (5.576) | 25.03 | bicubic | |
| gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | |
| gluon_resnet50_v1d | 79.074 (20.926) | 94.476 (5.524) | 25.58 | bicubic | |
| gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | |
| gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | |
| gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | |
| tf_efficientnet_b0 *tfp | 76.828 (23.172) | 93.226 (6.774) | 5.29 | bicubic | Google |
| tf_efficientnet_b0 | 76.528 (23.472) | 93.010 (6.990) | 5.29 | bicubic | Google |
| gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | |
| tflite_semnasnet_100 | 73.086 (26.914) | 91.336 (8.664) | 3.87 | bicubic | Google TFLite |
| tflite_mnasnet_100 | 72.398 (27.602) | 90.930 (9.070) | 4.36 | bicubic | Google TFLite |
| gluon_resnet18_v1b | 70.830 (29.170) | 89.756 (10.244) | 11.69 | bicubic | |

@ 240x240

| Model | Prec@1 (Err) | Prec@5 (Err) | Param # (M) | Image Scaling | Source |
|---|---|---|---|---|---|
| tf_efficientnet_b1 *tfp | 78.796 (21.204) | 94.232 (5.768) | 7.79 | bicubic | Google |
| tf_efficientnet_b1 | 78.554 (21.446) | 94.098 (5.902) | 7.79 | bicubic | Google |

@ 260x260

| Model | Prec@1 (Err) | Prec@5 (Err) | Param # (M) | Image Scaling | Source |
|---|---|---|---|---|---|
| tf_efficientnet_b2 *tfp | 79.782 (20.218) | 94.800 (5.200) | 9.11 | bicubic | Google |
| tf_efficientnet_b2 | 79.606 (20.394) | 94.712 (5.288) | 9.11 | bicubic | Google |

@ 299x299 and 300x300

| Model | Prec@1 (Err) | Prec@5 (Err) | Param # (M) | Image Scaling | Source |
|---|---|---|---|---|---|
| tf_efficientnet_b3 *tfp | 80.982 (19.018) | 95.332 (4.668) | 12.23 | bicubic | Google |
| tf_efficientnet_b3 | 80.874 (19.126) | 95.302 (4.698) | 12.23 | bicubic | Google |
| gluon_inception_v3 | 78.804 (21.196) | 94.380 (5.620) | 27.16 | bicubic | MXNet Gluon |
| tf_inception_v3 | 77.856 (22.144) | 93.644 (6.356) | 27.16 | bicubic | TensorFlow Slim |
| adv_inception_v3 | 77.576 (22.424) | 93.724 (6.276) | 27.16 | bicubic | TensorFlow Adv models |

NOTE: For some reason I can't hit the stated accuracy with my implementation of MNASNet and Google's TFLite weights. Using a TF equivalent of 'SAME' padding was important to get > 70%, but something small is still missing. Attempts to train my own weights from scratch with these models have so far leveled off in the same 72-73% range.

Models marked *tfp were scored with the --tf-preprocessing flag.

The tf_efficientnet and tflite_(se)mnasnet models require an equivalent of TensorFlow 'SAME' padding because their architecture results in asymmetric padding. I've added this in the model creation wrapper, but it does come with a performance penalty.
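TensorFlow's 'SAME' padding picks the total padding so that the output size is ceil(input / stride), and puts any odd leftover pixel on the bottom/right side, which is where the asymmetry comes from. The function below is a sketch of that rule for one spatial dimension; same_padding is a hypothetical name, not the wrapper used in this repo.

```python
import math
from typing import Tuple


def same_padding(in_size: int, kernel: int, stride: int = 1, dilation: int = 1) -> Tuple[int, int]:
    """Return (pad_before, pad_after) mimicking TF 'SAME' along one dimension."""
    out_size = math.ceil(in_size / stride)
    effective_k = (kernel - 1) * dilation + 1
    pad_total = max((out_size - 1) * stride + effective_k - in_size, 0)
    pad_before = pad_total // 2  # TF puts the extra pixel after, not before
    return pad_before, pad_total - pad_before


# 3x3 stride-2 conv on a 224 input: no padding before, one pixel after.
pads_s2 = same_padding(224, kernel=3, stride=2)
# 3x3 stride-1 conv: symmetric, one pixel each side.
pads_s1 = same_padding(224, kernel=3, stride=1)
```

For the stride-2 case, PyTorch's symmetric padding=1 produces the same 112-pixel output but shifts every window by one pixel, so ported weights see slightly different inputs. In PyTorch the asymmetric amounts could be applied with torch.nn.functional.pad before an unpadded conv; because they depend on the input size, a fully general version has to compute them in the forward pass.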

Script Usage

Training

The variety of training args is large, and not all combinations of options (or even all individual options) have been fully tested. For the training dataset, specify the base folder that contains the train and validation folders.

To train an SE-ResNet34 on ImageNet, locally distributed across 4 GPUs (one process per GPU), with a cosine schedule, a random-erasing probability of 50%, and per-pixel random values:

./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 -j 4
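The --sched cosine, --epochs, --warmup-epochs and --lr flags above describe a warmup-then-cosine learning rate curve. The per-epoch sketch below shows the general shape, with a linear warmup followed by a half-cosine decay; cosine_lr is hypothetical, and warmup_lr_init and min_lr are illustrative values rather than the script's defaults. The repository's scheduler may differ in details such as per-step updates or how warmup overlaps the cosine.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 0.4, total_epochs: int = 150,
              warmup_epochs: int = 5, warmup_lr_init: float = 1e-4,
              min_lr: float = 1e-5) -> float:
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr_init towards base_lr.
        return warmup_lr_init + (base_lr - warmup_lr_init) * epoch / warmup_epochs
    # Fraction of the post-warmup schedule completed, in [0, 1).
    t = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))


lrs = [cosine_lr(e) for e in range(150)]
```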

NOTE: NVIDIA APEX should be installed to run per-process distributed training via NVIDIA DDP, or to enable AMP mixed precision with the --amp flag.

Validation / Inference

The validation and inference scripts are similar in usage: the former outputs metrics on a validation set, the latter outputs the top-k class ids for each image to a CSV file. Specify the folder containing the validation images, not the base folder as in the training script.
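Writing the top-k class ids per image to CSV needs little more than the standard csv module. The sketch below assumes one row per image, the filename followed by k class ids in descending score order; write_topk_csv and that column layout are assumptions, not necessarily what inference.py emits.

```python
import csv
import io
from typing import Optional, Sequence, TextIO


def write_topk_csv(filenames: Sequence[str], scores: Sequence[Sequence[float]],
                   k: int = 5, out: Optional[TextIO] = None) -> TextIO:
    """Write 'filename,id_1,...,id_k' rows, ids ordered by descending score."""
    if out is None:
        out = io.StringIO()
    writer = csv.writer(out)
    for name, row in zip(filenames, scores):
        ranked = sorted(range(len(row)), key=row.__getitem__, reverse=True)
        writer.writerow([name, *ranked[:k]])
    return out


buf = write_topk_csv(["a.jpg", "b.jpg"], [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], k=2)
```

To write to disk instead, pass a file opened with open("topk.csv", "w", newline=""), as the csv module documentation recommends.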

To validate with the model's pretrained weights (if they exist):

python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained

To run inference from a checkpoint:

python inference.py /imagenet/validation/ --model mobilenetv3_100 --checkpoint ./output/model_best.pth.tar

TODO

A number of additions are planned for various projects, including:

  • Find optimal training hyperparams and create/port pretrained weights for the generic MobileNet variants
  • Benchmark model performance (speed + accuracy) across all models (make runnable as a script)
  • More training experiments
  • Make the folder/file layout compatible with usage as a module
  • Add usage examples to comments, and good hyperparameters for training
  • Comments, cleanup and the usual things that get pushed back

Contributors

rwightman, zhunzhong07, cclauss
