
mobileone's Introduction

MobileOne: An Improved One millisecond Mobile Backbone

This is an unofficial implementation of the paper "An Improved One millisecond Mobile Backbone". Its accuracy is close to that reported in the paper.

I implemented the MobileOne S0 architecture (the smallest variant) and validated it on the ImageNet-1k dataset. The validation accuracy is listed below.

mobile-one block

Model                    Before merging blocks   After merging blocks   FLOPs
original paper (S0)      none                    71.4                   275M
my implementation (S0)   70.470                  70.518                 274M

Note that I only trained the S0 version, but you can easily modify the code to train other versions; see "mobileone.py" for the configuration.

Unlike the original paper, this implementation does not use:

  • AutoAugment. In fact, S0 in the paper does not use AutoAugment either.
  • Annealed weight decay. I keep it constant at 4e-5.
  • Label smoothing regularization.
  • EMA update strategy.
  • Progressive learning curriculum. I train directly at 224px.
  • Custom weight decay loss. I use the optimizer's weight decay directly.

Very important

I thank grygielski for finding a bug in my implementation, one that is very easy to correct. Please refer to this issue.

Because I don't plan to retrain the model, I have not rewritten the code. To fix the mistake, just delete the "if" condition, changing this:

self.dw_bn_layer = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None

to

self.dw_bn_layer = nn.BatchNorm2d(in_channels)

For validation

I have released the pretrained model weights; click here to download. The test script validates the trained model and also generates a converted deploy model.

python test.py {your imagenet-1000k dataset path} deploy mobileone_s0_hello_best.pth.tar

The converted deploy model is saved to the file "mobileone_deploy_model.pt".
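The conversion to a deploy model relies on the linearity of convolution: parallel branches that see the same input and are summed can be replaced by a single convolution whose kernel and bias are the sums of the branch kernels and biases (once each BatchNorm has been folded in). The following is a minimal sketch of that idea for a 1x1 convolution at a single pixel, written in plain Python. The function names are illustrative and are not taken from test.py or mobileone.py.

```python
def conv1x1(x, kernel, bias):
    """Apply a 1x1 convolution at one pixel; kernel[o][i] maps input i to output o."""
    return [sum(k * xi for k, xi in zip(row, x)) + b for row, b in zip(kernel, bias)]


def merge_branches(branches):
    """Sum the kernels and biases of parallel branches into one (kernel, bias)."""
    kernels, biases = zip(*branches)
    merged_kernel = [[sum(vals) for vals in zip(*rows)] for rows in zip(*kernels)]
    merged_bias = [sum(vals) for vals in zip(*biases)]
    return merged_kernel, merged_bias


x = [1.0, -2.0]
branch_a = ([[0.5, 0.1], [0.0, -0.3]], [0.2, 0.0])
branch_b = ([[-0.2, 0.4], [0.7, 0.1]], [0.0, -0.1])
# A skip connection expressed as an identity kernel with zero bias.
identity = ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
branches = [branch_a, branch_b, identity]

# Multi-branch form: run every branch, then add the outputs.
multi_branch = [sum(vals) for vals in zip(*(conv1x1(x, k, b) for k, b in branches))]
# Merged form: a single convolution with summed parameters.
single = conv1x1(x, *merge_branches(branches))

assert all(abs(p - q) < 1e-9 for p, q in zip(multi_branch, single))
```

A real block also has to pad smaller kernels (for example 1x1 into 3x3) before summing; that step is omitted here.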

For training

I trained mobileone-s0 on eight 32 GB V100 GPUs, which took about 4 days.

python train.py -a mobileone_s0 --dist-url 'tcp://127.0.0.1:23333' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 --workers 32 {your imagenet-1000k dataset path} --tag hello --wd 4e-5

For use

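import torch
from mobileone import make_mobileone_s0  # assumes mobileone.py is on the import path
model = make_mobileone_s0(deploy=True)
# nn.Module has no load(); this assumes the file holds a state_dict
model.load_state_dict(torch.load('mobileone_deploy_model.pt'))
model.cuda()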

Acknowledgement

The entire code is based on the RepVGG repository. Thanks for its easy-to-use code.

mobileone's People

Contributors

shoutoutyangjie


mobileone's Issues

Validation loss : deployed vs full model

Hi,
Thank you very much for the implementation.
I would like to know if you still have the graphs of the full model validation loss and the inference (ready to deploy) model validation loss.
I would like to know how both models behaved in the very early stage of training (first 20 epochs).
Thank you very much for considering my request :)

Question about merging conv and BN

Hi, while reading your code I ran into a question about merging conv and BN.
When merging conv and BN:
[image]
why does your code do this:
std = (running_var + eps).sqrt()
t = (gamma / std).reshape(-1, 1, 1, 1)
return kernel * t, beta - running_mean * gamma / std

Why don't you use the conv's bias? I think the result should be:
weight = conv_weight * bn_weight.view(out_channels, 1, 1, 1) / bn_std.view(out_channels, 1, 1, 1)
bias = bn_weight * (conv_bias - bn_mean) / bn_std + bn_bias
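The two formulas in this question agree whenever the convolution has no bias: setting conv_bias to 0 in the general formula gives exactly beta - running_mean * gamma / std. Whether the repository's convolutions are created without a bias is an assumption here, not something this page states. The sketch below checks the folding for one output channel, using plain Python floats and a 1x1 convolution written as a dot product; all names are illustrative.

```python
import math


def fold_bn(weights, conv_bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm statistics into one output channel of a convolution."""
    std = math.sqrt(var + eps)
    scale = gamma / std
    fused_weights = [w * scale for w in weights]
    fused_bias = (conv_bias - mean) * scale + beta
    return fused_weights, fused_bias


def conv_then_bn(x, weights, conv_bias, gamma, beta, mean, var, eps=1e-5):
    """Reference path: convolution followed by inference-mode BatchNorm."""
    y = sum(w * xi for w, xi in zip(weights, x)) + conv_bias
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fused_conv(x, weights, bias):
    """Single convolution using the folded parameters."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias


x = [0.5, -1.0, 2.0]
w = [0.2, 0.4, -0.1]
bn = dict(gamma=1.5, beta=0.3, mean=0.1, var=0.25)

# Check both a bias-free convolution and one with a bias.
for conv_bias in (0.0, 0.7):
    fused_w, fused_b = fold_bn(w, conv_bias, **bn)
    reference = conv_then_bn(x, w, conv_bias, **bn)
    assert abs(fused_conv(x, fused_w, fused_b) - reference) < 1e-9
```

With a bias-free convolution the shorter expression is sufficient; the general expression is only needed when the convolution itself carries a bias.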

Reparameterization question

[image]

I like your work very much; it gave me a lot of inspiration. I am a beginner and do not understand some of the code. In this part of the code, I commented out "model = copy.deepcopy(model)". Does this line affect the final result? Why can't reparameterization be done directly?
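One common reason to keep such a copy is that reparameterization usually rewrites a model's layers in place, so converting a copy leaves the original multi-branch model available for further training or comparison. Whether the deploy output itself changes when the line is commented out depends on the repository's conversion code, which is not shown here. The toy sketch below uses a hypothetical ToyBlock class in plain Python as a stand-in for a real block.

```python
import copy


class ToyBlock:
    """Hypothetical stand-in for a multi-branch block; not the repository's class."""

    def __init__(self, branch_weights):
        self.branch_weights = list(branch_weights)
        self.deployed = False

    def switch_to_deploy(self):
        # In-place conversion: the separate branches are gone afterwards.
        self.branch_weights = [sum(self.branch_weights)]
        self.deployed = True

    def forward(self, x):
        return sum(w * x for w in self.branch_weights)


train_block = ToyBlock([0.5, 0.25, 1.0])

# Convert a copy so the trainable multi-branch block stays untouched.
deploy_block = copy.deepcopy(train_block)
deploy_block.switch_to_deploy()

assert deploy_block.forward(2.0) == train_block.forward(2.0)
assert len(train_block.branch_weights) == 3 and not train_block.deployed
assert len(deploy_block.branch_weights) == 1 and deploy_block.deployed
```

Without the copy, both names would refer to the same converted object, and the original branches could no longer be trained or inspected.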

Inference speed

Has the author measured the actual inference speed on an iPhone?

[BUG] Skip connection should be always added to DW branches.

Hi @shoutOutYangJie, I've noticed that you have a bug in your MobileOne implementation:

self.dw_bn_layer = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None

You are adding DW BN layer based on in_channels == out_channels condition but it should be always added because for DW part input channels are always equal to output channels. This condition should be only checked for PW part as there might be channel change.

Could you please provide dataset format?

Hello, I want to train my own model, but I don't know the expected dataset format.
Could you please provide it? Thank you.
This project has helped me a lot, and I really need this.
Please!

Why use a normal stem for self.stage0?

self.stage0 looks like a normal stem rather than the MobileOne block from the paper. Is there a reason why?

By the way, using a MobileOne block as the stem might cause a huge accuracy drop on CIFAR-100 (50%+).

About mobileone_s2

I have no idea about the parameters of mobileone_s2. Could you help offer the parameters to me?
