
pytorch_memonger's Introduction

PyTorch Memory optimizations via gradient checkpointing

This repository contains implementations of various PyTorch models using gradient checkpointing [1], which trades compute for memory and hence allows training bigger/wider models and using larger minibatch sizes.
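As a minimal sketch of the idea (a hedged example using the torch.utils.checkpoint API; the module sizes and batch size here are illustrative, not taken from this repository):

import torch
from torch.utils.checkpoint import checkpoint

# Two illustrative sub-networks; checkpointing stores only their inputs
# and outputs, and recomputes internal activations during backward().
block1 = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())
block2 = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU())

x = torch.randn(64, 512, requires_grad=True)
h = checkpoint(block1, x)  # activations inside block1 are not kept
y = checkpoint(block2, h)  # recomputed on the fly during backward()
y.sum().backward()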

The application of checkpointing is showcased on various models:

  • ResNet
  • DenseNet
  • LSTM model from the PyTorch examples here
  • VNet model, used in medical imaging applications, available here

Results of checkpointing on these models are showcased below:

[benchmark results figure]

In order to use the models, you need to install PyTorch master by following the instructions here.

To run checkpointed models and their baseline tests, follow the commands below:

# for checkpointed
python test_memory_optimized.py

# for baseline
python test_memory_baseline.py

Tutorial

We provide a tutorial describing how to use checkpointing for various kinds of models.

A few special kinds of layers, such as batch normalization and dropout, need to be handled carefully. The details of handling them are covered in the tutorial; a rough sketch of the dropout caveat follows below.
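This sketch assumes the preserve_rng_state flag exposed by torch.utils.checkpoint; see the tutorial for the authoritative treatment:

import torch
from torch.utils.checkpoint import checkpoint

# Dropout draws random numbers, so the recomputed forward pass must see
# the same RNG state as the original one; checkpoint preserves it by default.
block = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Dropout(p=0.5))
x = torch.randn(8, 256, requires_grad=True)
out = checkpoint(block, x, preserve_rng_state=True)  # True is the default

# Batch normalization needs different care: its running statistics are
# updated on every forward pass, so a checkpointed BN layer in training
# mode sees each batch twice (original forward plus recomputation).
out.sum().backward()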

References

[1]. Siskind, Jeffrey Mark, and Barak A. Pearlmutter. "Divide-and-Conquer Checkpointing for Arbitrary Programs with No User Annotation." arXiv preprint arXiv:1708.06799 (2017).

pytorch_memonger's People

Contributors

prigoyal


pytorch_memonger's Issues

checkpointing for model sharding

I split my model across two GPUs and use checkpointing to reduce memory usage, but I still get an OOM error when calling backward(). It seems like all the data is processed on one of my GPUs during backward()? (I'm using GPUs 2 and 3.)

[Screenshot 1: memory usage before backward()]

[Screenshot 2: memory usage after backward()]
Any advice is appreciated.

Is checkpoint_sequential still in torch.utils ?

Hi,

Thank you for the great repository :)

I searched quickly but wasn't able to locate the implementation of checkpoint_sequential in the master branch of PyTorch (which is specified as a dependency for your tutorial).
Is it still supposed to be there and I just somehow couldn't find it?

Thank you.
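(For reference, in released PyTorch versions the function is exposed from torch.utils.checkpoint rather than torch.utils itself:)

# checkpoint_sequential lives in the torch.utils.checkpoint module,
# not in torch.utils directly.
from torch.utils.checkpoint import checkpoint, checkpoint_sequential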

NameError: name 'self' is not defined

This is probably an issue on my part, but when I run the "Checkpointing_for_PyTorch_models" tutorial and get to self.assertEqual(out_checkpointed, out_not_checkpointed), I get the following error: NameError: name 'self' is not defined. Am I missing something?

Thanks in advance.

WLM test is failing

I am working with a 0.4.0a0 version, and the WLM test is failing in the repackage_hidden() function. The error message is copied below. I think it has something to do with the fact that the Variable API has been deprecated: Variable functionality still works, but type() returns torch.Tensor rather than Variable.
I changed repackage_hidden() to:

def repackage_hidden(self, h):
    """Wraps hidden states in new Variables, to detach them from their history."""
    # if type(h) == Variable:
    if type(h) in [Variable, Tensor]:
        return Variable(h.data)
    else:
        return tuple(self.repackage_hidden(v) for v in h)

and it seems to work. Could you please check the issue and validate the fix?

=======================================================================
ERROR: test_wlm_baseline (__main__.TestMemoryBaseline)

.....
.....
.....
File "test_memory_baseline.py", line 219, in repackage_hidden
return tuple(self.repackage_hidden(v) for v in h)
File "/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/tensor.py", line 351, in iter
raise TypeError('iteration over a 0-d tensor')
TypeError: iteration over a 0-d tensor
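(For reference, a sketch of an equivalent fix on builds where Variable has been merged into Tensor, using detach() instead of rewrapping in a Variable:)

import torch

def repackage_hidden(h):
    """Detach hidden states from their history (post Variable/Tensor merge)."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)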

benchmark variables are not clear

Hi Priya,

I was trying to reproduce your benchmark results. I think two important pieces of information are missing. First, which GPU did you use to train these models, and how much memory does it have? Second, I don't understand how you computed "max image" and what it corresponds to. For example, in VNet, does it correspond to N?

Benchmark clarification

Hi Priya,
It is unclear how you ran the benchmark. I noticed that for best results I need to run each unit test independently; otherwise the first tests seem to "steal" memory from later tests. If the benchmark is intended to run as a whole, you probably need to find a way to free the memory between tests (a sketch of one possibility is below). Otherwise, it may be worth stating clearly in the README that the tests are supposed to be run independently.
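(One possible way to release memory between tests, assuming the caching allocator is what lets earlier tests hold on to memory; a hedged sketch, not code from this repository:)

import torch

def reset_gpu_memory():
    # Return cached blocks to the driver so later tests start clean.
    torch.cuda.empty_cache()
    # Reset the peak-memory counters so per-test peaks are comparable
    # (available in recent PyTorch releases).
    torch.cuda.reset_peak_memory_stats()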
Thanks
Dmitri
