Comments (7)
> Interestingly, after 3 to 4 iterations, the memory usage stabilizes.

Yeah, this sounds like normal allocator behavior. You can look into https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html, but I think this is probably:
- Outside the scope of PyTorch
- Unrelated to the GPU-CPU transfers.
You can also try some other allocator like https://jemalloc.net/, but you'll probably see similar behavior.
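As a quick way to see that plateau for yourself, here is a minimal sketch (assuming a Linux box with glibc, a CUDA device, and the third-party `psutil` package, none of which come from the report above) that watches the process RSS across iterations:

```python
import os

import psutil
import torch

proc = psutil.Process(os.getpid())

def rss_mib() -> float:
    """Resident set size of the current process, in MiB."""
    return proc.memory_info().rss / 2**20

# Allocate, round-trip through the GPU, and free a large batch repeatedly.
# With glibc malloc, RSS typically climbs for the first few iterations and
# then plateaus once the allocator starts reusing its cached arenas.
for i in range(6):
    batch = torch.randn(300, 1000, 1000)   # ~1.1 GiB of float32 on the host
    batch = batch.to("cuda").to("cpu")
    del batch
    torch.cuda.empty_cache()
    print(f"iteration {i}: RSS = {rss_mib():.1f} MiB")
```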
I'm also having this problem. Does anyone have a solution or an idea of how to solve it?
cc @albanD
Quick observations:
- The memory is always stable between the `with no_grad` line and the last line of the function.
- I'm not sure how to interpret one line showing 452.1 MB and the next showing 333.8 MB when the increment is 0?
Also, on CPU, libc malloc is used for memory allocation. Depending on which memory measure you're looking at, there are quite a few cases where the allocator will keep memory around and not give it back to the OS. This would explain the first few iterations allocating more until malloc starts to properly re-use memory it has cached.
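One way to check that this is caching rather than a leak: on glibc-based Linux you can ask the allocator to hand its free cache back to the OS and see whether RSS drops. A sketch via `ctypes` (glibc-specific; `malloc_trim` does not exist on musl, macOS, or Windows):

```python
import ctypes

# glibc-only: malloc_trim(0) releases free memory at the top of the heap
# (and unused arena pages) back to the OS. If RSS drops after this call,
# the growth was allocator caching, not a leak in PyTorch.
libc = ctypes.CDLL("libc.so.6")
released = libc.malloc_trim(0)  # returns 1 if any memory was released
print("malloc_trim released memory:", bool(released))
```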
Thanks for the reply. I'll try to explain a bit better.
> The memory is always stable between the `with no_grad` line and the last line of the function.

Indeed, the memory remains stable between the `with torch.no_grad()` context manager and the final line of the function. However, that stable level is not the same as the memory in use when the function starts, and the starting baseline grows from one iteration to the next.
To illustrate the memory usage throughout the function, I've added print statements to highlight the differences:
```
PyTorch version: 2.2.2+cu121
******* Iteration num: 1 ***********
Start "with torch_no_grad()"
End "with torch_no_grad()"
End run_test()
Filename: test.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
8 332.1 MiB 332.1 MiB 1 @profile
9 def run_test():
10
11 332.1 MiB 0.0 MiB 1 print('Start "with torch_no_grad()"')
12
13 449.8 MiB 0.0 MiB 2 with torch.no_grad():
14 332.1 MiB 0.0 MiB 1 batch_size = 300
15 332.1 MiB 0.0 MiB 1 tensor_size = (1000, 1000)
16 # Create a batch tensor in one line
17 1479.1 MiB 1147.0 MiB 303 batch_tensors = torch.stack([torch.randn(tensor_size) for _ in range(batch_size)]).to('cpu')
18
19 449.7 MiB -1029.5 MiB 1 batch_tensors = batch_tensors.to('cuda')
20 1594.1 MiB 1144.4 MiB 1 batch_tensors = batch_tensors.to('cpu').detach()
21 # Print the size of the batch tensor
22 449.8 MiB -1144.3 MiB 1 del batch_tensors
23
24 449.8 MiB 0.0 MiB 1 print('End "with torch_no_grad()"')
25
26 449.8 MiB 0.0 MiB 1 gc.collect()
27 449.8 MiB 0.0 MiB 1 torch.cuda.empty_cache()
28
29 449.8 MiB 0.0 MiB 1 print('End run_test()')
******* Iteration num: 2 ***********
Start "with torch_no_grad()"
End "with torch_no_grad()"
End run_test()
Filename: test.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
8 449.8 MiB 449.8 MiB 1 @profile
9 def run_test():
10
11 449.8 MiB 0.0 MiB 1 print('Start "with torch_no_grad()"')
12
13 1594.5 MiB 0.0 MiB 2 with torch.no_grad():
14 449.8 MiB 0.0 MiB 1 batch_size = 300
15 449.8 MiB 0.0 MiB 1 tensor_size = (1000, 1000)
16 # Create a batch tensor in one line
17 2738.7 MiB 2288.9 MiB 303 batch_tensors = torch.stack([torch.randn(tensor_size) for _ in range(batch_size)]).to('cpu')
18
19 1594.5 MiB -1144.2 MiB 1 batch_tensors = batch_tensors.to('cuda')
20 2738.7 MiB 1144.2 MiB 1 batch_tensors = batch_tensors.to('cpu').detach()
21 # Print the size of the batch tensor
22 1594.5 MiB -1144.2 MiB 1 del batch_tensors
23
24 1594.5 MiB 0.0 MiB 1 print('End "with torch_no_grad()"')
25
26 1594.5 MiB 0.0 MiB 1 gc.collect()
27 1594.5 MiB 0.0 MiB 1 torch.cuda.empty_cache()
28
29 1594.5 MiB 0.0 MiB 1 print('End run_test()')
```
> I'm not sure how to interpret one line showing 452.1 MB and the next showing 333.8 MB when the increment is 0?

The line with 452.1 MB likely represents the memory allocated at the conclusion of the `with torch.no_grad()` block.
> Also, on CPU, libc malloc is used for memory allocation. Depending on which memory measure you're looking at, there are quite a few cases where the allocator will keep memory around and not give it back to the OS. This would explain the first few iterations allocating more until malloc starts to properly re-use memory it has cached.

Is there a way to prevent this behavior, especially considering that in the second iteration, memory usage increases excessively?
Would you recommend exploring alternative memory allocator implementations as a potential solution?
> Is there a way to prevent this behavior, especially considering that in the second iteration, memory usage increases excessively?

You can try another malloc implementation like jemalloc, but it will most likely show similar behavior.
In particular, as long as there is no memory pressure, it is usually faster for the allocator to keep memory around, since cached memory can be handed back out faster than memory freshly requested from the OS.
So unless you actually see OOMs, it might just be keeping memory around to speed things up.
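For reference, the glibc tunables linked earlier can also be set programmatically. A hedged sketch using `mallopt` through `ctypes` (glibc/Linux only; it must run early, before the large allocations happen, and the 128 KiB thresholds are illustrative values, not recommendations):

```python
import ctypes

# Constants from glibc's malloc.h.
M_TRIM_THRESHOLD = -1  # free space at the top of the heap above this goes back to the OS
M_MMAP_THRESHOLD = -3  # requests larger than this are served by mmap, unmapped on free()

# Lower both thresholds so large freed blocks are returned to the OS more
# eagerly, trading allocation speed for a smaller resident footprint.
libc = ctypes.CDLL("libc.so.6")
libc.mallopt(M_TRIM_THRESHOLD, 128 * 1024)
libc.mallopt(M_MMAP_THRESHOLD, 128 * 1024)
```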
> > Is there a way to prevent this behavior, especially considering that in the second iteration, memory usage increases excessively?
>
> You can try another malloc implementation like jemalloc, but it will most likely show similar behavior. In particular, as long as there is no memory pressure, it is usually faster for the allocator to keep memory around, since cached memory can be handed back out faster than memory freshly requested from the OS.
> So unless you actually see OOMs, it might just be keeping memory around to speed things up.

Thanks for the answer. I'll experiment with different malloc implementations to see if the behavior persists.
My main concern is that this issue also occurs when loading and transferring models from CPU to GPU. I'm encountering out-of-memory errors. It seems strange that the model loads successfully the first few times, but then requires significantly more memory on subsequent attempts.
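If the OOMs happen while staging checkpoints in host RAM before moving them to the GPU, one thing worth trying is mapping the checkpoint straight onto the device, so the large CPU-side copy never lands in malloc's cache. A sketch with hypothetical names (`checkpoint.pt` and `resnet50` are placeholders for the actual model and file):

```python
import gc

import torch
from torchvision.models import resnet50  # placeholder model

model = resnet50().to("cuda")

# map_location="cuda" deserializes the tensors directly onto the GPU,
# skipping the intermediate CPU copy that would otherwise stay cached
# in the host allocator.
state = torch.load("checkpoint.pt", map_location="cuda")
model.load_state_dict(state)

del state
gc.collect()               # drop the temporary Python-side references
torch.cuda.empty_cache()   # return unused blocks to the CUDA driver
```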