Hmm interesting. I would have suspected some memory overhead, but not nearly this much! Thank you for the carefully written issue. Sadly though, I am also not able to reproduce this on my machine. Could you please describe your hardware / setup?
Just to confirm: this is CPU / host memory you are talking about, correct? Not GPU?
FYI: There is some buffering going on from "prefetch_batches: int = 300", but with 8x8 images this would mean 8*8 bytes per image (at 1 byte/pixel) * 4 (splits) * 128 (batch size) * 300 (batches) ≈ 10 MB... So it is not this.
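The back-of-the-envelope estimate above can be checked directly; this is just the arithmetic from the comment (8x8 single-byte pixels, 4 splits, batch size 128, prefetch_batches = 300), not code from the library:

```python
# Rough upper bound on memory held by the prefetch buffer, using the
# numbers quoted in the comment above.
bytes_per_image = 8 * 8    # 8x8 images at 1 byte per pixel
splits = 4                 # number of dataset splits being prefetched
batch_size = 128
prefetch_batches = 300

total_bytes = bytes_per_image * splits * batch_size * prefetch_batches
print(total_bytes)                    # 9830400
print(f"{total_bytes / 1e6:.1f} MB")  # 9.8 MB
```

About 10 MB, which is far too small to explain a multi-gigabyte blow-up, so the prefetch buffer can be ruled out.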
from learned_optimization.
Hi Luke, thank you very much for your quick response. I hope the following details can be helpful.
I am using a GPU:
- Driver Version: 470.82.01
- Cuda Version: 11.4
- NVIDIA GeForce RTX 3060
Python
- 3.8
- tensorflow: 2.7.0
- jax: 0.2.27 (installed with GPU support via:
pip install --upgrade "jax[cuda114]" jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
)
More Details:
I first noticed the large GPU memory usage when running the pes.py example:
python pes.py --train_log_dir somedir
I understand there are some prefetch batches, but as you calculated, the prefetched data should be very small.
Again, thank you for your reply. I am also still investigating this issue and will let you know once I find something.
Thanks for the info and being such an early tester!
Just to confirm that is NOT GPU memory, but an explosion in host (CPU) memory?
If you are observing GPU memory, see: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for flags on how to turn that off.
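For reference, the flags described on that page are ordinary environment variables; the sketch below shows the two most common ones (flag names are from the linked JAX docs, the 30% value is just illustrative):

```shell
# JAX reads these variables at process startup, so set them before
# launching the Python process that imports jax.

# Option 1: allocate GPU memory on demand instead of preallocating most
# of the device up front.
export XLA_PYTHON_CLIENT_PREALLOCATE=false

# Option 2 (alternative): keep preallocation but cap it at a fraction of
# total GPU memory, e.g. 30%.
# export XLA_PYTHON_CLIENT_MEM_FRACTION=.30
```

With preallocation disabled, tools like nvidia-smi show JAX's actual working-set size rather than the large up-front reservation.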
Do you still see memory increases if you turn off the GPU? e.g. something like:
CUDA_VISIBLE_DEVICES= python pes.py --train_log_dir ....
Hello, thank you so much! Your comment is very helpful, especially the JAX GPU memory allocation link.
> Just to confirm that is NOT gpu memory, but an explosion in host (CPU) memory?
In my case, it was a GPU memory explosion.
> If you are observing GPU memory, see: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for flags on how to turn that off.
In my case, I finally managed to reduce the GPU memory (from 10 GB+ to ~700 MB when running pes.py) based on the suggestions at the above link.
What I did was:
- Disable TensorFlow's use of the GPU, since TensorFlow is only used to load the fashion_mnist dataset. I added this line in base.py:
tf.config.experimental.set_visible_devices([], "GPU")
- Disable JAX's GPU memory preallocation by setting the environment variable:
export XLA_PYTHON_CLIENT_PREALLOCATE=false
The link https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html lists several options for avoiding GPU OOM errors, so there may be other solutions as well.
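The two fixes above can be sketched together in one place. This is a minimal sketch, not the library's actual code; the TensorFlow import is guarded so the ordering point still holds even on a machine without TensorFlow installed:

```python
import os

# JAX reads XLA_PYTHON_CLIENT_PREALLOCATE at import time, so it must be
# set *before* `import jax` runs anywhere in the process. With
# preallocation off, JAX allocates GPU memory on demand instead of
# reserving most of the device up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

try:
    import tensorflow as tf

    # Hide all GPUs from TensorFlow: tf.data is only used here to load
    # fashion_mnist, so the input pipeline can stay on the CPU and stop
    # competing with JAX for device memory.
    tf.config.experimental.set_visible_devices([], "GPU")
except ImportError:
    pass  # TensorFlow not installed; the env var above still applies to JAX
```

The key detail is ordering: setting the environment variable after JAX has already been imported has no effect, which is why exporting it in the shell before launching the script also works.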
Ahh, TF also tries to grab the GPU. That is annoying; I should fix that on my end. I'll make an issue. Thanks for posting your solution here!