pytorch-dpn-pretrained's Introduction

PyTorch Pretrained Dual Path Networks (DPN)

This repository includes a PyTorch implementation of Dual Path Networks (https://arxiv.org/abs/1707.01629) that works with cypw's pretrained weights.

The code is based upon cypw's original MXNet implementation (https://github.com/cypw/DPNs) with oyam's PyTorch implementation (https://github.com/oyam/pytorch-DPNs) as a reference.

Original testing of these models and all validation was done with torch (0.2.0.post1) and mxnet (0.11.0) pip packages installed. The models have since been updated and tested with Conda installs of PyTorch 1.0 and 1.1.

Pretrained

The model weights have already been converted to PyTorch and are hosted at a fixed URL. To use these pretrained weights, call the model entrypoint functions with pretrained=True.

PyTorch Hub

Models can also be accessed via the PyTorch Hub API:

>>> import torch
>>> torch.hub.list('rwightman/pytorch-dpn-pretrained')
['dpn68', ...]
>>> model = torch.hub.load('rwightman/pytorch-dpn-pretrained', 'dpn68', pretrained=True)
>>> model.eval()
>>> with torch.no_grad():
...     output = model(torch.randn(1, 3, 224, 224))
...
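Continuing the session above, the output for a single image is assumed to be a [1, 1000] tensor of raw ImageNet-1K class scores (logits). A minimal sketch for turning such scores into top-5 class indices and probabilities follows; topk_predictions is a hypothetical helper, not part of this repository, and random logits stand in for a real model call so nothing is downloaded:

```python
import torch
import torch.nn.functional as F


def topk_predictions(logits: torch.Tensor, k: int = 5):
    """Return (probabilities, class indices) of the k highest-scoring classes per row.

    Hypothetical helper; assumes logits has shape [batch, num_classes].
    """
    probs = F.softmax(logits, dim=1)           # convert logits to probabilities
    top_probs, top_idx = probs.topk(k, dim=1)  # sorted from most to least likely
    return top_probs, top_idx


# Stand-in for `model(torch.randn(1, 3, 224, 224))`; no weights are loaded here.
with torch.no_grad():
    logits = torch.randn(1, 1000)
    probs, idx = topk_predictions(logits)
print(idx.tolist(), [round(p, 4) for p in probs[0].tolist()])
```

Mapping the indices to human-readable labels requires an ImageNet class-index list, which is not included here.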

Conversion

If you want to convert the weights yourself, download the trained weight files from https://github.com/cypw/DPNs#trained-models and untar them into a './pretrained' folder in the directory containing this code.

The weights can then be converted by running the conversion script as follows:

python convert_from_mxnet.py ./pretrained/ --model dpn107

Results

The following tables contain validation results on ImageNet-1K, produced by the included validation code. The DPN models use the weights converted from the pretrained MXNet models. For reference, results are also included for Torchvision's ResNet and DenseNet models, as well as InceptionV4 and InceptionResnetV2 ports (by Cadene, https://github.com/Cadene/pretrained-models.pytorch).

All DPN runs at image sizes above 224x224 use the mean-max pooling scheme (https://github.com/cypw/DPNs#mean-max-pooling) described by cypw.
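As we understand cypw's description, mean-max pooling applies the classifier convolutionally when the test image is larger than the training size. The final feature map is average-pooled with the training-time window (7x7, stride 1, assuming 224 / 32 = 7), the 1x1-conv classifier is applied at every remaining position, and the resulting score map is reduced by averaging its global mean and its global max. The PyTorch sketch below works under those assumptions; the function name and window default are illustrative, not code taken from this repository:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mean_max_pool_logits(features: torch.Tensor, classifier: nn.Module,
                         window: int = 7) -> torch.Tensor:
    """Sketch of mean-max test-time pooling.

    features:   [N, C, H, W] final feature map with H, W >= window
                (e.g. 10x10 for a 320x320 input at an assumed output stride of 32).
    classifier: 1x1 convolution mapping C channels to class scores.
    Returns [N, num_classes] logits.
    """
    # Average-pool with the training-time window and stride 1, so each output
    # position covers the same field of view as global pooling did at 224x224.
    pooled = F.avg_pool2d(features, kernel_size=window, stride=1)
    # Apply the classifier at every remaining spatial position: [N, K, h, w].
    scores = classifier(pooled)
    # Mean-max: average the global mean and the global max of the score map.
    avg = F.adaptive_avg_pool2d(scores, 1).flatten(1)
    mx = F.adaptive_max_pool2d(scores, 1).flatten(1)
    return 0.5 * (avg + mx)


clf = nn.Conv2d(64, 10, kernel_size=1)  # toy classifier: 64 channels -> 10 classes
feats = torch.randn(1, 64, 10, 10)      # stand-in for features of a 320x320 input
with torch.no_grad():
    print(mean_max_pool_logits(feats, clf).shape)  # torch.Size([1, 10])
```

With a 7x7 feature map (a 224x224 input) the score map is 1x1, so its mean and max coincide and the result equals ordinary global average pooling followed by the classifier. This is consistent with the scheme only mattering above 224x224.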

Note that results are sensitive to image crop, scaling interpolation, and even the image library used. All image operations for these models are performed with PIL. Bicubic interpolation is used for all models except ResNet, where bilinear produced better results. InceptionV4 and InceptionResnetV2 gave better results at 100% crop; all other networks evaluated at their native training resolution use an 87.5% crop.
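To make the crop percentages concrete, here is the arithmetic under the common ImageNet evaluation convention (assumed here, not taken from this repository's validation code). The image is resized so its shorter side is img_size / crop_pct and then center-cropped to img_size x img_size. A small sketch; resize_and_crop_sizes is a hypothetical helper:

```python
import math


def resize_and_crop_sizes(img_size: int, crop_pct: float):
    """Return (shorter-side resize target, center-crop size) for a crop percentage.

    Assumes the common convention: resize the shorter side to img_size / crop_pct,
    then take a centered img_size x img_size crop.
    """
    if not 0.0 < crop_pct <= 1.0:
        raise ValueError("crop_pct must be in (0, 1]")
    scale_size = int(math.floor(img_size / crop_pct))
    return scale_size, img_size


for size, pct in [(224, 0.875), (299, 1.0), (320, 1.0)]:
    print(size, pct, resize_and_crop_sizes(size, pct))
```

An 87.5% crop at 224 therefore resizes to 256 before cropping, while a 100% crop at 299 or 320 resizes directly to the evaluation size. In a torchvision pipeline these two numbers would typically feed transforms.Resize (with bicubic or bilinear interpolation, per the note above) and transforms.CenterCrop.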

Models with a '*' are using weights that were trained on ImageNet-5k and fine-tuned on ImageNet-1k. The MXNet weights files for these have an '-extra' suffix in their name.

Results @224x224

Model             Prec@1 (Err)     Prec@5 (Err)    #Params (M)  Crop
DenseNet121       74.752 (25.248)  92.152 (7.848)  7.98         87.5%
ResNet50          76.130 (23.870)  92.862 (7.138)  25.56        87.5%
DenseNet169       75.912 (24.088)  93.024 (6.976)  14.15        87.5%
DualPathNet68     76.346 (23.654)  93.008 (6.992)  12.61        87.5%
ResNet101         77.374 (22.626)  93.546 (6.454)  44.55        87.5%
DenseNet201       77.290 (22.710)  93.478 (6.522)  20.01        87.5%
DenseNet161       77.348 (22.652)  93.646 (6.354)  28.68        87.5%
DualPathNet68b*   77.528 (22.472)  93.846 (6.154)  12.61        87.5%
ResNet152         78.312 (21.688)  94.046 (5.954)  60.19        87.5%
DualPathNet92     79.128 (20.872)  94.448 (5.552)  37.67        87.5%
DualPathNet98     79.666 (20.334)  94.646 (5.354)  61.57        87.5%
DualPathNet131    79.806 (20.194)  94.706 (5.294)  79.25        87.5%
DualPathNet92*    80.034 (19.966)  94.868 (5.132)  37.67        87.5%
DualPathNet107    80.172 (19.828)  94.938 (5.062)  86.92        87.5%

Results @299x299 (test_time_pool=True for DPN)

Model             Prec@1 (Err)     Prec@5 (Err)    #Params (M)  Crop
InceptionV3       77.436 (22.564)  93.476 (6.524)  27.16        87.5%
DualPathNet68     78.006 (21.994)  94.158 (5.842)  12.61        100%
DualPathNet68b*   78.582 (21.418)  94.470 (5.530)  12.61        100%
InceptionV4       80.138 (19.862)  95.010 (4.990)  42.68        100%
DualPathNet92*    80.408 (19.592)  95.190 (4.810)  37.67        100%
DualPathNet92     80.480 (19.520)  95.192 (4.808)  37.67        100%
InceptionResnetV2 80.492 (19.508)  95.270 (4.730)  55.85        100%
DualPathNet98     81.062 (18.938)  95.404 (4.596)  61.57        100%
DualPathNet131    81.208 (18.792)  95.630 (4.370)  79.25        100%
DualPathNet107*   81.432 (18.568)  95.706 (4.294)  86.92        100%

Results @320x320 (test_time_pool=True)

Model             Prec@1 (Err)     Prec@5 (Err)    #Params (M)  Crop
DualPathNet68     78.450 (21.550)  94.358 (5.642)  12.61        100%
DualPathNet68b*   78.764 (21.236)  94.726 (5.274)  12.61        100%
DualPathNet92*    80.824 (19.176)  95.570 (4.430)  37.67        100%
DualPathNet92     80.960 (19.040)  95.500 (4.500)  37.67        100%
DualPathNet98     81.276 (18.724)  95.666 (4.334)  61.57        100%
DualPathNet131    81.458 (18.542)  95.786 (4.214)  79.25        100%
DualPathNet107*   81.800 (18.200)  95.910 (4.090)  86.92        100%

pytorch-dpn-pretrained's People

Contributors

andfoy, cadene, rwightman

pytorch-dpn-pretrained's Issues

AssertionError when converting dpn107: 'num_batches_tracked' is not handled by the BN conversion function

My environment is Anaconda3 with MXNet 1.1.0 and PyTorch 0.4.1.
While running "python convert_from_mxnet.py ./pretrained --model dpn107", I got:

Traceback (most recent call last):
  File "convert_from_mxnet.py", line 136, in <module>
    main()
  File "convert_from_mxnet.py", line 129, in main
    convert_from_mxnet(model, checkpoint_base)
  File "convert_from_mxnet.py", line 51, in convert_from_mxnet
    aux, key_add = _convert_bn(k[3])
  File "convert_from_mxnet.py", line 31, in _convert_bn
    assert False
AssertionError

I then added some debug prints and got:

[17:04:05] src/nnvm/legacy_json_util.cc:190: Loading symbol saved by previous version v0.8.0. Attempting to upgrade...
[17:04:05] src/nnvm/legacy_json_util.cc:198: Symbol successfully upgraded!
db now layer i:2,['features', 'conv1_1', 'bn', 'weight']
db now layer i:3,['features', 'conv1_1', 'bn', 'bias']
db now layer i:4,['features', 'conv1_1', 'bn', 'running_mean']
db now layer i:5,['features', 'conv1_1', 'bn', 'running_var']
db now layer i:6,['features', 'conv1_1', 'bn', 'num_batches_tracked']
db in _convert_bn, k is num_batches_tracked
Traceback (most recent call last):
  File "convert_from_mxnet.py", line 136, in <module>
    main()
  File "convert_from_mxnet.py", line 129, in main
    convert_from_mxnet(model, checkpoint_base)
  File "convert_from_mxnet.py", line 51, in convert_from_mxnet
    aux, key_add = _convert_bn(k[3])
  File "convert_from_mxnet.py", line 31, in _convert_bn
    assert False
AssertionError

How to fix this?
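The debug output points at the cause. Starting with PyTorch 0.4.1, BatchNorm modules register an extra num_batches_tracked buffer in their state_dict, and this buffer has no counterpart in the MXNet checkpoint. A converter that asserts on unknown BatchNorm keys therefore fails on it. A minimal sketch of a key mapping that skips that buffer instead of asserting is shown below. It assumes MXNet's standard BatchNorm names (gamma and beta as arguments, moving_mean and moving_var as auxiliary states). convert_bn_key and its return convention are hypothetical, not the repository's actual _convert_bn:

```python
from typing import Optional, Tuple

# PyTorch BatchNorm key suffix -> (is an MXNet auxiliary state?, MXNet suffix).
_BN_KEY_MAP = {
    "weight": (False, "gamma"),
    "bias": (False, "beta"),
    "running_mean": (True, "moving_mean"),
    "running_var": (True, "moving_var"),
}

# Buffers PyTorch adds with no MXNet equivalent. num_batches_tracked is only used
# by BatchNorm when momentum=None, so leaving its default is fine for inference.
_BN_SKIP_KEYS = {"num_batches_tracked"}


def convert_bn_key(k: str) -> Optional[Tuple[bool, str]]:
    """Map a PyTorch BatchNorm key suffix to (aux, mxnet_suffix); None means skip."""
    if k in _BN_SKIP_KEYS:
        return None
    if k not in _BN_KEY_MAP:
        raise KeyError(f"Unexpected BatchNorm key: {k}")
    return _BN_KEY_MAP[k]


for key in ["weight", "bias", "running_mean", "running_var", "num_batches_tracked"]:
    mapped = convert_bn_key(key)
    if mapped is None:
        continue  # keep the PyTorch default for this buffer
    aux, mx_suffix = mapped
    print(key, "->", mx_suffix, "(aux)" if aux else "(arg)")
```

In the conversion loop, the caller would check for None and continue before looking the key up in the MXNet parameters.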

Convert ResNet from MXNet to PyTorch

Dear @rwightman,
Thank you for your nice repository. I have a pre-trained ResNet152 model in MXNet and want to convert it to PyTorch. Could you please guide me on how to do that?
I tried using convert_from_mxnet.py (with some small modifications), but ran into the following error:
AssertionError: Unexpected token

Error with torch.hub.load when using num_classes kwarg

When trying to access dpn107 via the PyTorch Hub API, I run into this error:

Error(s) in loading state_dict for DPN:
    size mismatch for classifier.weight: copying a param with shape torch.Size([1000, 2688, 1, 1]) from checkpoint, the shape in current model is torch.Size([16, 2688, 1, 1]).
    size mismatch for classifier.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([16]).

The call I used was:
torch.hub.load('rwightman/pytorch-dpn-pretrained', 'dpn107', num_classes=16, pretrained=True)

When specifying the num_classes arg, this error appears.
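The shapes in the message explain the failure. The checkpoint's classifier is a 1x1 convolution with 1000 output channels, whereas num_classes=16 builds one with 16, so the pretrained classifier tensors cannot be copied. One common workaround, sketched here under assumptions, is to copy only the tensors whose names and shapes match and to leave the new classifier randomly initialized for fine-tuning. load_matching_state_dict is a hypothetical helper, and the toy network only stands in for DPN:

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def load_matching_state_dict(model: nn.Module, state_dict: dict) -> list:
    """Copy every tensor whose name and shape match `model`; return the skipped keys."""
    own = model.state_dict()
    filtered = {k: v for k, v in state_dict.items()
                if k in own and own[k].shape == v.shape}
    skipped = [k for k in state_dict if k not in filtered]
    # strict=False tolerates the keys we dropped (here, the resized classifier).
    model.load_state_dict(filtered, strict=False)
    return skipped


def toy_net(num_classes: int) -> nn.Module:
    # Stand-in for DPN: a feature layer followed by a 1x1-conv classifier.
    return nn.Sequential(OrderedDict([
        ("features", nn.Conv2d(3, 8, kernel_size=3)),
        ("classifier", nn.Conv2d(8, num_classes, kernel_size=1)),
    ]))


pretrained = toy_net(1000)  # plays the role of the 1000-class checkpoint
target = toy_net(16)        # the model built with num_classes=16
skipped = load_matching_state_dict(target, pretrained.state_dict())
print(skipped)  # ['classifier.weight', 'classifier.bias']
```

Alternatively, assuming the entrypoint is called with the default 1000 classes, the pretrained model can be loaded as usual and its head replaced afterwards, e.g. model.classifier = torch.nn.Conv2d(2688, 16, kernel_size=1). Here 2688 is the classifier input width reported in the error above.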