efficient_densenet_pytorch

A PyTorch implementation of DenseNets, optimized to save GPU memory.

Motivation

While DenseNets are fairly easy to implement in deep learning frameworks, most implementations (such as the original) tend to be memory-hungry. In particular, the number of intermediate feature maps generated by batch normalization and concatenation operations grows quadratically with network depth. It is worth emphasizing that this is not a property inherent to DenseNets, but rather of the implementation.

This implementation uses a new strategy to reduce the memory consumption of DenseNets. We assign all intermediate feature maps to two shared memory allocations, which are reused by every batch normalization and concatenation operation. Because the data in these allocations is temporary, we re-populate the outputs during back-propagation. This adds 15-20% time overhead to training, but reduces the memory consumed by feature maps from quadratic to linear in network depth.
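To make the quadratic-versus-linear claim concrete, here is a rough counting sketch. It is a simplified model and not the repository's code: it assumes a single dense block in which every layer adds `growth_rate` channels, ignores bottleneck layers and spatial resolution, and counts stored channels rather than bytes. The function names and the example sizes are hypothetical.

```python
def naive_stored_channels(num_layers, growth_rate, init_channels):
    """Channels kept in memory when every layer stores its own
    concatenation output and batch-norm output (assumed naive scheme)."""
    total = init_channels  # input to the block
    for layer in range(num_layers):
        in_channels = init_channels + layer * growth_rate
        total += 2 * in_channels  # concat output + batch-norm output
        total += growth_rate      # convolution output (new features)
    return total


def shared_stored_channels(num_layers, growth_rate, init_channels):
    """Channels kept in memory when concat and batch-norm outputs live in
    two shared buffers sized for the widest layer (assumed efficient scheme)."""
    widest_input = init_channels + (num_layers - 1) * growth_rate
    new_features = num_layers * growth_rate
    return init_channels + new_features + 2 * widest_input


for depth in (16, 32):
    print(depth,
          naive_stored_channels(depth, 12, 24),
          shared_stored_channels(depth, 12, 24))
```

In this model, doubling the number of layers more than triples the naive count, while the shared-buffer count stays below double.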

For more details, please see the technical report.

Diagram of implementation

Usage

Note:

This demo was initially developed on PyTorch v0.1.12. For reasons that are still unknown, it does not pass all the unit tests on PyTorch v0.2, but the performance (final accuracy) remains the same.

  • If you would like to help us improve it on v0.2, please follow this issue.

  • To downgrade PyTorch from v0.2 or higher, please refer to these instructions.

In your existing project: There are three files in the models folder. They work as stand-alone files, so you can copy any one of them into your project.

  • models/densenet.py is a "naive" implementation, based off the torchvision and project killer implementations.
  • models/densenet_efficient.py is the new efficient implementation. (Code is still a little ugly. We're working on cleaning it up!)
  • models/densenet_efficient_multi_gpu.py is the new efficient implementation with multi-GPU support.

Running the demo:

  • single GPU:
CUDA_VISIBLE_DEVICES=0 python demo.py --efficient True --data <path_to_data_dir> --save <path_to_save_dir>
  • multiple GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3 python demo.py --multi-gpu True --data <path_to_data_dir> --save <path_to_save_dir>

Options:

  • --depth (int) - depth of the network (number of convolution layers) (default 40)
  • --growth_rate (int) - number of features added per DenseNet layer (default 12)
  • --n_epochs (int) - number of epochs for training (default 300)
  • --batch_size (int) - size of minibatch (default 256)
  • --seed (int) - manually set the random seed (default None)
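As an illustration of how the flags above fit together, here is a minimal sketch of a command-line parser built with Python's standard argparse module. It is not taken from demo.py, which may parse its arguments differently. The names `build_parser` and `str_to_bool`, the help strings, and the `False` defaults for `--efficient` and `--multi-gpu` are assumptions made for this example; only the flag names and the documented defaults come from this README.

```python
import argparse


def str_to_bool(value):
    """Parse 'True'/'False' style strings, as used in `--efficient True`."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    # Hypothetical parser that mirrors the documented flags and defaults.
    parser = argparse.ArgumentParser(description="DenseNet demo (sketch)")
    parser.add_argument("--data", required=True, help="path to data dir")
    parser.add_argument("--save", required=True, help="path to save dir")
    # Defaults for the next two flags are assumptions, not documented values.
    parser.add_argument("--efficient", type=str_to_bool, default=False)
    parser.add_argument("--multi-gpu", type=str_to_bool, default=False)
    parser.add_argument("--depth", type=int, default=40)
    parser.add_argument("--growth_rate", type=int, default=12)
    parser.add_argument("--n_epochs", type=int, default=300)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--seed", type=int, default=None)
    return parser


# Equivalent to the single-GPU command line, with placeholder paths.
example_args = build_parser().parse_args(
    ["--efficient", "True", "--data", "data_dir", "--save", "save_dir"]
)
print(vars(example_args))
```

Note that argparse exposes `--multi-gpu` as the attribute `multi_gpu`, while flags written with underscores keep their names.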

Performance

A comparison of the implementations (each is a DenseNet-BC with 100 layers, batch size 64, tested on an NVIDIA Pascal Titan-X):

| Implementation        | Memory consumption (GB/GPU) | Speed (sec/mini-batch) |
|-----------------------|-----------------------------|------------------------|
| Naive                 | 2.863                       | 0.165                  |
| Efficient             | 1.605                       | 0.207                  |
| Efficient (multi-GPU) | 0.985                       | -                      |
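For readers who want the relative differences, the figures in the table above can be turned into percentages with a few lines of arithmetic. The helper name `percent_change` is hypothetical; the numbers are copied from the table and nothing else is measured here.

```python
# Figures copied from the table (DenseNet-BC, 100 layers, batch size 64).
naive_mem_gb, naive_sec = 2.863, 0.165
efficient_mem_gb, efficient_sec = 1.605, 0.207
multi_gpu_mem_gb = 0.985


def percent_change(new, old):
    """Relative change from old to new, in percent."""
    return 100.0 * (new - old) / old


memory_change = percent_change(efficient_mem_gb, naive_mem_gb)
time_change = percent_change(efficient_sec, naive_sec)
multi_gpu_memory_change = percent_change(multi_gpu_mem_gb, naive_mem_gb)

print("memory, efficient vs naive: %.1f%%" % memory_change)
print("time per mini-batch, efficient vs naive: %.1f%%" % time_change)
print("memory per GPU, multi-GPU vs naive: %.1f%%" % multi_gpu_memory_change)
```

On these numbers the efficient implementation uses roughly 44% less memory at a cost of roughly 25% more time per mini-batch, which is somewhat above the 15-20% overhead quoted under Motivation; the two figures may come from different settings.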

Other efficient implementations

Contributors

gpleiss, taineleau, jiahuiyu, zhengrui
