
ds-net's Introduction

Dynamic Slimmable Network (DS-Net)

This repository contains PyTorch code of our paper: Dynamic Slimmable Network (CVPR 2021 Oral).

[Figure] Architecture of DS-Net. The width of each supernet stage is adjusted adaptively by the slimming ratio ρ predicted by the gate.

[Figure] Accuracy vs. complexity on ImageNet.

Pretrained Supernet

  • Supernet Checkpoint

  • Here is a summary of sub-network performance of the pretrained supernet (a small checkpoint-inspection sketch follows the table):

    Subnetwork | MAdds | Top-1 (%) | Top-5 (%)
    0          | 133M  | 70.1      | 89.4
    1          | 153M  | 70.4      | 89.6
    2          | 175M  | 70.8      | 89.9
    3          | 200M  | 71.2      | 90.2
    4          | 226M  | 71.6      | 90.3
    5          | 255M  | 72.0      | 90.6
    6          | 286M  | 72.4      | 90.9
    7          | 319M  | 72.7      | 91.0
    8          | 355M  | 73.0      | 91.2
    9          | 393M  | 73.3      | 91.4
    10         | 433M  | 73.6      | 91.5
    11         | 475M  | 73.9      | 91.7
    12         | 519M  | 74.1      | 91.8
    13         | 565M  | 74.6      | 92.0
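
A quick way to inspect the downloaded checkpoint before wiring it into the configs is sketched below; the file name and the internal layout (e.g. a 'state_dict' entry) are assumptions rather than guarantees of the repo.

    import torch

    # Hypothetical inspection of the pretrained supernet checkpoint; the path and
    # the presence of a 'state_dict' wrapper are assumptions.
    ckpt = torch.load('/PATH/TO/supernet_checkpoint.pth', map_location='cpu')
    print(type(ckpt), list(ckpt.keys()) if isinstance(ckpt, dict) else None)

    # If it wraps a state dict, peek at a few parameter names and shapes.
    state_dict = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
    for i, (name, value) in enumerate(state_dict.items()):
        print(name, tuple(value.shape) if hasattr(value, 'shape') else type(value))
        if i >= 4:
            break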

Usage

1. Requirements

2. Stage I: Supernet Training

For example, train the dynamic slimmable MobileNet supernet with 8 GPUs (takes about 2 days):

python -m torch.distributed.launch --nproc_per_node=8 train.py /PATH/TO/ImageNet -c ./configs/mobilenetv1_bn_uniform.yml

3. Stage II: Gate Training

  • Modify resume: in configs/mobilenetv1_bn_uniform_reset_bn.yml to point to your supernet checkpoint, then recalibrate BN before gate training:

    python -m torch.distributed.launch --nproc_per_node=8 train.py /PATH/TO/ImageNet -c ./configs/mobilenetv1_bn_uniform_reset_bn.yml
    
  • Modify resume: in configs/mobilenetv1_bn_uniform_gate.yml to point to your supernet checkpoint after BN recalibration (or to our pretrained Supernet Checkpoint), then start gate training:

    python -m torch.distributed.launch --nproc_per_node=8 train.py /PATH/TO/ImageNet -c ./configs/mobilenetv1_bn_uniform_gate.yml
    

Citation

If you use our code for your paper, please cite:

@inproceedings{li2021dynamic,
  author = {Changlin Li and
            Guangrun Wang and
            Bing Wang and
            Xiaodan Liang and
            Zhihui Li and
            Xiaojun Chang},
  title = {Dynamic Slimmable Network},
  booktitle = {CVPR},
  year = {2021}
}

ds-net's People

Contributors

changlin31


ds-net's Issues

Pretrained models

Hi, is it possible to release some of the pretrained models? Thank you!

Why is num_choice different in different yml files?

Why do you set num_choice to 4 in mobilenetv1_bn_uniform_reset_bn.yml, but to 14 in the other two yml files?

Bro, if you're also from **, let's just communicate in Chinese; my English is rather poor...

Why not set ensemble_ib to True?

Hi,

I found that ensemble_ib is set to False for both slim training and gate training in the configs, but according to the paper it should boost performance when set to True.

Any idea?

Dynamic path for DS-mobilenet

Hi. Thanks for your work. I am reading your paper and trying to reimplement it, and I am confused about some details.
You mention in your paper that the slimming ratio ρ∈[0.35 : 0.05 : 1.25], which gives 18 paths.
However, in your code there are only 14 paths, ρ∈[0.35 : 0.05 : 1], as defined in

[list(range(736, 1152 + 1, 32)), 2, 3, 2, 'ds', False],

. Also, when conducting gate training, the gate function only has a 4-dimensional output, meaning that there are only 4 paths and the slimming ratio is restricted to ρ∈[0.35 : 0.05 : 0.5].
channel_gate_num=4 if has_gate else 0)

Why are the dynamic paths for the larger widths not used?
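
For concreteness, here is a small sketch of how the quoted config line maps to selectable widths and how the 4-way gate restricts them; the assumption that the gate indexes the slimmest 4 widths follows the reading in this issue.

    # 14 channel options for the last stage, exactly as in the config line quoted above.
    channels = list(range(736, 1152 + 1, 32))
    print(len(channels), channels[0], channels[-1])   # 14 736 1152

    # With channel_gate_num=4, the gate's 4-dimensional output can only pick among
    # 4 of these widths; per this issue's reading, the slimmest 4.
    gate_choices = channels[:4]
    print(gate_choices)   # [736, 768, 800, 832]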

Object Detection

Can you please share the code for object detection? Thank you!

DS-Net for object detection

Hello. Thanks for your work. I noticed that you also conducted some experiments on object detection. I wonder whether, and when, you will release that code.

Question about calculating MAdds of dynamic network in the paper

Thank you for your great work, and I have a question about how to calculate MAdds in your paper.
A dynamic network has a different width and different MAdds for each input instance, but you report a single MAdds number for your networks.
Are they the average MAdds for the whole dataset?
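
If the reported number is indeed a dataset average, it could be computed as in the toy sketch below; the per-path costs reuse the first four entries of the supernet table above, and the averaging protocol itself is an assumption, not a statement by the authors.

    import torch

    # Toy average-MAdds computation over a pretend evaluation set.
    madds_per_path = [133e6, 153e6, 175e6, 200e6]   # costs of 4 sub-networks (from the README table)
    chosen_paths = torch.randint(0, 4, (1000,))     # pretend gate decisions, one per sample

    avg_madds = sum(madds_per_path[i] for i in chosen_paths.tolist()) / len(chosen_paths)
    print(f"average MAdds over the set: {avg_madds / 1e6:.1f}M")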

Actual acceleration on ResNet

Thank you for your great work! I have a question about latency: can the method achieve actual acceleration on ResNet?

The usage of gumbel softmax in DS-Net

Thank you for your very nice work. I want to understand the effect of the Gumbel softmax, because I think the network could be trained without it.
Is the Gumbel softmax just intended to increase the randomness of the channel choice?

Error when changing num_choice in mobilenetv1_bn_uniform_reset_bn.yml

I followed your suggestion and set num_choice in mobilenetv1_bn_uniform_reset_bn.yml to 14, but I get an unexpected error when I run python -m torch.distributed.launch --nproc_per_node=8 train.py /PATH/TO/ImageNet -c ./configs/mobilenetv1_bn_uniform_reset_bn.yml.

08/25 10:15:57 AM Recalibrating BatchNorm statistics...
08/25 10:16:10 AM Finish recalibrating BatchNorm statistics.
08/25 10:16:19 AM Finish recalibrating BatchNorm statistics.
08/25 10:16:21 AM Test: [ 0/0] Mode: 0 Time: 0.344 (0.344) Loss: 6.9204 (6.9204) Prec@1: 0.0000 ( 0.0000) Prec@5: 0.0000 ( 0.0000) Flops: 132890408 (132890408)
08/25 10:16:22 AM Test: [ 0/0] Mode: 1 Time: 0.406 (0.406) Loss: 6.9189 (6.9189) Prec@1: 0.0000 ( 0.0000) Prec@5: 0.0000 ( 0.0000) Flops: 152917440 (152917440)
08/25 10:16:22 AM Test: [ 0/0] Mode: 2 Time: 0.381 (0.381) Loss: 6.9187 (6.9187) Prec@1: 0.0000 ( 0.0000) Prec@5: 0.0000 ( 0.0000) Flops: 175152224 (175152224)
08/25 10:16:23 AM Test: [ 0/0] Mode: 3 Time: 0.389 (0.389) Loss: 6.9134 (6.9134) Prec@1: 0.0000 ( 0.0000) Prec@5: 0.0000 ( 0.0000) Flops: 199594752 (199594752)
Traceback (most recent call last):
  File "train.py", line 658, in <module>
    main()
  File "train.py", line 635, in main
    eval_metrics.append(validate_slim(model,
  File "/home/chauncey/PycharmProjects/DS-Net-main/dyn_slim/apis/train_slim.py", line 215, in validate_slim
    output = model(input)
  File "/home/chauncey/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chauncey/PycharmProjects/DS-Net-main/dyn_slim/models/dyn_slim_net.py", line 191, in forward
    x = self.forward_features(x)
  File "/home/chauncey/PycharmProjects/DS-Net-main/dyn_slim/models/dyn_slim_net.py", line 178, in forward_features
    x = stage(x)
  File "/home/chauncey/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chauncey/PycharmProjects/DS-Net-main/dyn_slim/models/dyn_slim_stages.py", line 48, in forward
    x = self.first_block(x)
  File "/home/chauncey/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chauncey/PycharmProjects/DS-Net-main/dyn_slim/models/dyn_slim_blocks.py", line 240, in forward
    x = self.conv_pw(x)
  File "/home/chauncey/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chauncey/PycharmProjects/DS-Net-main/dyn_slim/models/dyn_slim_ops.py", line 94, in forward
    self.running_outc = self.out_channels_list[self.channel_choice]
IndexError: list index out of range

It looks like some adjustments are needed in other .py files as well.
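
A minimal, hypothetical reproduction of the failure mode the traceback suggests: out_channels_list is still built with only 4 entries (the old num_choice) while channel_choice can now reach higher modes.

    # Hypothetical repro: widths list built for num_choice=4, but mode 4 requested.
    out_channels_list = [736, 768, 800, 832]   # only 4 entries (illustrative widths)
    channel_choice = 4                         # the first mode beyond the list

    try:
        running_outc = out_channels_list[channel_choice]
    except IndexError as e:
        print(e)   # list index out of range, as in the traceback above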

project environment

Hi, could you provide the environment setup for the project? I tried to train the network with python=3.8, pytorch=1.7.1, cuda=10.2. Shortly after training starts, a RuntimeError: CUDA error: device-side assert triggered occurs, and some other environments also lead to this error. I'm not sure whether the problem is caused by the difference in environment.

Can we further improve autoslim without the gate?

It is not easy to deploy gate operator with some other backends, like TensorRT.

So my question is: can we further improve autoslim without the dynamic gate at inference time? Is there any ongoing work on this?

MAdds of Pretrained Supernet

Hi Changlin, your work is excellent. I have a question about the calculation of MAdds: in README.md the MAdds of Subnetwork 13 is 565M, but based on my experiments I think it should be 821M, because the channel number of Subnetwork 13 is larger than that of the original MobileNetV1, and the original MobileNetV1 1.0's MAdds should already be 565M. Looking forward to your reply.

UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.

Why do I get this warning:
/home/chauncey/.local/lib/python3.8/site-packages/torchvision/transforms/functional.py:364: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum. warnings.warn(
when I use
python3 -m torch.distributed.launch --nproc_per_node=1 train.py ./imagenet -c ./configs/mobilenetv1_bn_uniform.yml
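
The warning comes from torchvision deprecating integer interpolation codes in favor of the InterpolationMode enum; it is harmless, but can be silenced by passing the enum wherever the resize transform is built (exactly where that happens in this repo's data pipeline is not pinned down here).

    from torchvision import transforms
    from torchvision.transforms import InterpolationMode

    # An integer PIL code triggers the warning, e.g.:
    #   transforms.Resize(256, interpolation=3)   # 3 == bicubic -> UserWarning
    # Passing the enum keeps the same behavior without the warning:
    resize = transforms.Resize(256, interpolation=InterpolationMode.BICUBIC)
    print(resize)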

Softmax twice for SGS loss?

Dear authors, thanks for this nice work.

I wonder why the calculation of the SGS loss uses the softmaxed output rather than the logits, considering that PyTorch's CrossEntropyLoss already contains a softmax inside.

g_loss = loss_fn(m.keep_gate, gate_target)
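
For context (before the longer code quote below), a small sketch of the behavior being questioned: nn.CrossEntropyLoss applies log-softmax internally, so feeding it already-softmaxed gate scores applies softmax twice and flattens the distribution. The tensors here are illustrative only.

    import torch
    import torch.nn as nn

    loss_fn = nn.CrossEntropyLoss()
    logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])   # raw gate logits (illustrative)
    target = torch.tensor([0])

    # CrossEntropyLoss already contains log-softmax, so it expects raw logits:
    loss_on_logits = loss_fn(logits, target)

    # Passing softmax(logits) instead applies softmax twice:
    loss_on_probs = loss_fn(logits.softmax(dim=1), target)

    print(loss_on_logits.item(), loss_on_probs.item())   # the two values differ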

import torch

# Fragment of the gate module's forward(): the straight-through Gumbel-softmax
# output over the channel choices is cached in self.channel_choice for get_gate().
            self.keep_gate, self.print_gate, self.print_idx = gumbel_softmax(
                channel_choice, dim=1, training=self.training)
            self.channel_choice = self.print_gate, self.print_idx
        else:
            self.channel_choice = None
        return x

    def get_gate(self):
        return self.channel_choice


def gumbel_softmax(logits, tau=1, hard=False, dim=1, training=True):
    """See `torch.nn.functional.gumbel_softmax()`"""
    # Sample Gumbel(0, 1) noise and perturb the logits (temperature tau).
    gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()  # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)
    with torch.no_grad():
        # Hard one-hot vector taken from the argmax of the soft sample.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    # Straight-through estimator: forward pass uses y_hard, gradients flow through y_soft.
    ret = y_hard - y_soft.detach() + y_soft
    return y_soft, ret, index
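
For reference, a toy call of the gumbel_softmax() helper quoted above (it assumes that function is in scope; shapes and values are purely illustrative).

    import torch

    # 2 samples, 4 channel-configuration choices.
    logits = torch.randn(2, 4)
    y_soft, y_hard_st, index = gumbel_softmax(logits, tau=1, dim=1)

    print(y_soft.sum(dim=1))     # ~tensor([1., 1.]): soft probabilities over choices
    print(y_hard_st.sum(dim=1))  # tensor([1., 1.]): one-hot in the forward pass
    print(index.view(-1))        # chosen configuration index per sample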

Runtime problem

Could someone tell me why the following error occurs?
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.


/root/anaconda3/envs/0108/lib/python3.6/site-packages/torchvision/io/image.py:11: UserWarning: Failed to load image Python extension: /root/anaconda3/envs/0108/lib/python3.6/site-packages/torchvision/image.so: undefined symbol: _ZNK3c106IValue23reportToTensorTypeErrorEv
warn(f"Failed to load image Python extension: {e}")
/root/anaconda3/envs/0108/lib/python3.6/site-packages/torchvision/io/image.py:11: UserWarning: Failed to load image Python extension: /root/anaconda3/envs/0108/lib/python3.6/site-packages/torchvision/image.so: undefined symbol: _ZNK3c106IValue23reportToTensorTypeErrorEv
warn(f"Failed to load image Python extension: {e}")
01/21 05:42:18 AM Added key: store_based_barrier_key:1 to store for rank: 1
01/21 05:42:18 AM Added key: store_based_barrier_key:1 to store for rank: 0
01/21 05:42:18 AM Training in distributed mode with multiple processes, 1 GPU per process. Process 0, total 2.
01/21 05:42:18 AM Training in distributed mode with multiple processes, 1 GPU per process. Process 1, total 2.
01/21 05:42:20 AM Model slimmable_mbnet_v1_bn_uniform created, param count: 7676204
01/21 05:42:20 AM Data processing configuration for current model + dataset:
01/21 05:42:20 AM input_size: (3, 224, 224)
01/21 05:42:20 AM interpolation: bicubic
01/21 05:42:20 AM mean: (0.485, 0.456, 0.406)
01/21 05:42:20 AM std: (0.229, 0.224, 0.225)
01/21 05:42:20 AM crop_pct: 0.875
01/21 05:42:20 AM NVIDIA APEX not installed. AMP off.
01/21 05:42:21 AM Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.
01/21 05:42:21 AM Scheduled epochs: 40
01/21 05:42:21 AM Training folder does not exist at: images/train
01/21 05:42:21 AM Training folder does not exist at: images/train
Killing subprocess 239
Killing subprocess 240
Traceback (most recent call last):
  File "/root/anaconda3/envs/0108/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/root/anaconda3/envs/0108/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/root/anaconda3/envs/0108/lib/python3.6/site-packages/torch/distributed/launch.py", line 340, in <module>
    main()
  File "/root/anaconda3/envs/0108/lib/python3.6/site-packages/torch/distributed/launch.py", line 326, in main
    sigkill_handler(signal.SIGTERM, None)  # not coming back
  File "/root/anaconda3/envs/0108/lib/python3.6/site-packages/torch/distributed/launch.py", line 301, in sigkill_handler
    raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
subprocess.CalledProcessError: Command '['/root/anaconda3/envs/0108/bin/python', '-u', 'train.py', '--local_rank=1', 'images', '-c', './configs/mobilenetv1_bn_uniform_reset_bn.yml']' returned non-zero exit status 1.

Commands to perform inference

Hi authors,
Thanks for releasing the code and pre-trained model.
Could you provide a small script or some instructions to perform inference in a dynamic mode?
I am more interested in observing how each sample activates respective paths.

Thanks in advance!
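
Until an official script is available, here is a toy, hypothetical sketch of the inspection pattern the question asks for. ToyGatedBlock and everything inside it are stand-ins rather than the repo's real modules; only the get_gate() accessor name is borrowed from the code quoted in the Gumbel-softmax issue above.

    import torch
    import torch.nn as nn

    class ToyGatedBlock(nn.Module):
        """Toy stand-in for a gated DS-Net block (illustration only)."""

        def __init__(self, num_choice=4):
            super().__init__()
            self.gate = nn.Linear(8, num_choice)   # hypothetical gate head
            self.channel_choice = None

        def forward(self, x):
            logits = self.gate(x.mean(dim=(2, 3)))      # pooled features -> gate logits
            self.channel_choice = logits.argmax(dim=1)  # chosen slimming-ratio index per sample
            return x

        def get_gate(self):                             # mirrors the accessor name in the repo code
            return self.channel_choice

    model = nn.Sequential(ToyGatedBlock(), ToyGatedBlock())
    x = torch.randn(2, 8, 14, 14)                       # 2 samples, toy feature map
    with torch.no_grad():
        model(x)

    # Walk the modules and print which path each gate picked for each sample.
    for name, m in model.named_modules():
        if hasattr(m, 'get_gate') and m.get_gate() is not None:
            print(name, m.get_gate().tolist())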
