Comments (5)

lukemetz commented on April 28, 2024

Hmm interesting. I would have suspected some memory overhead, but not nearly this much! Thank you for the carefully written issue. Sadly though, I am also not able to reproduce this on my machine. Could you please describe your hardware / setup?

Just to confirm: this is CPU / host memory you are talking about, correct? Not GPU?

FYI: There is some buffering going on with the "prefetch_batches: int = 300", but with 8x8 images this would mean 8*8*4 (4 splits) * 128 (batch size) * 300 bytes, or about 10 MB. So it is not this.
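The buffer estimate in this comment can be checked with quick arithmetic. The sketch below assumes the factors are 8 x 8 pixels per image, 4 splits, a batch size of 128, and 300 prefetched batches, with one byte per pixel value (uint8); the one-byte assumption is not stated in the thread.

```python
# Assumed factors, read from the estimate in the comment above.
height, width = 8, 8        # 8x8 images
num_splits = 4              # "4 splits"
batch_size = 128
prefetch_batches = 300      # prefetch_batches: int = 300
bytes_per_value = 1         # assumption: uint8 pixels

total_bytes = (height * width * num_splits * batch_size
               * prefetch_batches * bytes_per_value)
total_mb = total_bytes / 1e6

print(f"{total_bytes} bytes = {total_mb:.1f} MB")
```

Even if the values were stored as 4-byte floats instead, the buffer would stay around 40 MB, far below the memory usage reported in the issue.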

from learned_optimization.

createmomo commented on April 28, 2024

Hi Luke, thank you very much for your quick response. I hope the following details can be helpful.

I am using a GPU:

  • Driver Version: 470.82.01
  • Cuda Version: 11.4
  • NVIDIA GeForce RTX 3060

Python

  • 3.8
  • tensorflow: 2.7.0
  • jax: 0.2.27 (this is how I install jax to support gpu: pip install --upgrade "jax[cuda114]" jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html)

More Details:
I first noticed the large GPU memory usage when running the pes.py example:
python pes.py --train_log_dir somedir


I understand there are some prefetch batches. But as you calculated, the pre-fetched data should be very small.


Again, thank you for your reply. I am still investigating this issue and will let you know once I find something.

lukemetz commented on April 28, 2024

Thanks for the info and being such an early tester!

Just to confirm that is NOT gpu memory, but an explosion in host (CPU) memory?

If you are observing GPU memory, see: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for flags on how to turn that off.

Do you still see memory increases if you turn off the GPU? e.g. something like:

CUDA_VISIBLE_DEVICES= python pes.py --train_log_dir ....
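The same CPU-only check can be sketched from Python by launching the script with an environment in which `CUDA_VISIBLE_DEVICES` is empty. The helper name `cpu_only_env` is hypothetical, and the commented-out launch line reuses the `somedir` log directory from earlier in the thread purely as an illustration.

```python
import os

def cpu_only_env(base_env=None):
    """Return a copy of the environment with all CUDA devices hidden."""
    env = dict(os.environ if base_env is None else base_env)
    # An empty value hides every GPU from CUDA-based libraries in the child.
    env["CUDA_VISIBLE_DEVICES"] = ""
    return env

env = cpu_only_env({"PATH": "/usr/bin"})
print(repr(env["CUDA_VISIBLE_DEVICES"]))

# Hypothetical launch, equivalent to the shell command above:
# import subprocess, sys
# subprocess.run([sys.executable, "pes.py", "--train_log_dir", "somedir"],
#                env=cpu_only_env(), check=True)
```

If memory still grows with the GPU hidden, the growth is on the host side rather than in GPU memory.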

createmomo commented on April 28, 2024

Hello, thank you so much! Your comment is very helpful, especially the JAX GPU memory allocation link.

> Just to confirm that is NOT gpu memory, but an explosion in host (CPU) memory?

In my case, it was GPU memory that exploded.

> If you are observing GPU memory, see: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for flags on how to turn that off.

In my case, I managed to reduce GPU memory usage when running pes.py from 10 GB+ to about 700 MB by following the suggestions at the link above.

What I did:

  • Prevent TensorFlow from using the GPU, since TensorFlow is only used here to load the fashion_mnist dataset. I added this line to base.py:
tf.config.experimental.set_visible_devices([], "GPU")
  • Disable JAX's GPU memory preallocation by setting this Linux environment variable:
export XLA_PYTHON_CLIENT_PREALLOCATE=false

The page https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html describes several other options for avoiding GPU OOM errors, so there may be other solutions.
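The two changes described in this comment can also be applied together from Python at the top of a script, as a minimal sketch. It assumes the code runs before JAX or TensorFlow touch the GPU; the helper name `hide_gpus_from_tensorflow` is hypothetical, and it uses `tf.config.set_visible_devices`, the non-experimental name for the call quoted above.

```python
import os

# Must be set before JAX initializes its GPU backend, so do it before
# importing jax anywhere in the process.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

def hide_gpus_from_tensorflow():
    """Hide all GPUs from TensorFlow; return False if TensorFlow is absent."""
    try:
        import tensorflow as tf
    except ImportError:
        return False
    # TensorFlow is only used for data loading here, so it needs no GPU.
    tf.config.set_visible_devices([], "GPU")
    return True

tf_configured = hide_gpus_from_tensorflow()
print("TensorFlow GPUs hidden:", tf_configured)
```

Setting the variable in code rather than with `export` keeps the setting with the script, but it only takes effect if it runs before the first JAX GPU allocation.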

lukemetz commented on April 28, 2024

Ahh, TF also tries to grab the GPU. That is annoying. I should fix that on my end. Going to make an issue. Thanks for posting your solution here!
