Comments (19)
We should not be initializing CUDA upon import.
Yes, totally agree, but I also don't see a cuInit here (unless I'm using the wrong PyTorch build).
We've recently seen a real CUDA init call while importing PyTorch in #116276 and worked on a unit test to make sure this won't happen again.
CC @malfet, @Aidyn-A, @atalman, as the test landed in #117043.
Running the script from this issue in a debugger also does not show any cuInit calls:

```
lldb --batch -o "b cuInit" -o "run" -- python3 -c "import torch;print(torch.cuda.device_count() > 0); import torch._higher_order_ops.map"
```

while adding `; torch.randn(1).cuda()` indeed breaks.
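For a quick check without a debugger, a minimal sketch along these lines should also work (run on a GPU machine): torch.cuda.is_initialized() reports whether the lazy CUDA state has been set up.

```python
# Sanity check: importing torch and its submodules should not initialize CUDA.
import torch
import torch._higher_order_ops.map  # the submodule from the original repro

assert not torch.cuda.is_initialized(), "CUDA was initialized on import!"

torch.randn(1).cuda()  # this, by contrast, really does initialize CUDA
assert torch.cuda.is_initialized()
```

Note this only catches full cuInit-style initialization, not the weaker NVML-based querying discussed further below.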
The script tries to set CUDA_VISIBLE_DEVICES after torch was imported, which is too late. This env variable should be set before importing any CUDA-enabled libraries in the script, or exported in your terminal.
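For completeness, a minimal sketch of the intended ordering:

```python
# CUDA_VISIBLE_DEVICES must be set *before* the first import of torch (or any
# other CUDA-enabled library); once the device count has been queried and
# cached, later changes to the variable may be silently ignored.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # do this first

import torch
print(torch.cuda.device_count())  # reflects the mask set above
```

Alternatively, export the variable in the shell before launching Python: `CUDA_VISIBLE_DEVICES=0 python script.py`.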
Nevertheless, we might still want to delay the device count call to keep `import torch` lightweight.
sure, I can have a look soon @ezyang
Oh. So if we get rid of the LRU cache when CUDA is not initialized, that should fix this once and for all. I'll send a patch.
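For reference, a minimal sketch of that idea, with hypothetical helper and variable names (the actual patch may look quite different): only burn the value into the cache once CUDA has really been initialized, so earlier calls keep re-reading the environment.

```python
import torch

_cached_count = None  # hypothetical module-level cache


def device_count() -> int:
    # Once CUDA is initialized, CUDA_VISIBLE_DEVICES can no longer change the
    # visible set, so caching is safe; before that, re-query on every call.
    global _cached_count
    if _cached_count is not None:
        return _cached_count
    # Raw query that does not create a CUDA context (error handling and the
    # non-NVML fallback are omitted in this sketch).
    count = torch.cuda._device_count_nvml()
    if torch.cuda.is_initialized():
        _cached_count = count
    return count
```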
Yeah, your patch is still good, I still want it in
Marking with the pt2 oncall since most of the relevant files are PT2-related.
Hi @ezyang,
Thanks for your quick response here. Could you please help connect us with the relevant engineer to fix this? We are updating to the latest PyTorch for our products https://github.com/Project-MONAI/MONAI and https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/containers/monai-toolkit, and it's a blocker for the release.
Thanks in advance.
Could it be this? #121864
Looking into this further, I think it's related to the device_count caching issue #95073; more recently, this module-level device_count call seems to make the issue more noticeable: pytorch/torch/_dynamo/device_interface.py, line 198 at 4c70ab2.
@wyli Voz isn't at Meta anymore; do you think you could send a patch to make this lazy or something?
I put a possible fix here in case it's useful: #122795.
I think there is a notion of initialization which is weaker than cuInit: some CUDA API calls, like torch.cuda.device_count, burn in the device count based on the value of the env var at the time of the call. So we are probably testing the cuInit case, but NOT this case (which is more pernicious, because AFAICT the only way you can notice it has happened is that setting CUDA_VISIBLE_DEVICES stops working).
Actually, reading over our code, this may be an entirely PyTorch-inflicted problem:

```cpp
DeviceIndex device_count() noexcept {
  // initialize number of devices only once
  static int count = []() {
    try {
      auto result = device_count_impl(/*fail_if_no_driver=*/false);
      TORCH_INTERNAL_ASSERT(
          result <= std::numeric_limits<DeviceIndex>::max(),
          "Too many CUDA devices, DeviceIndex overflowed");
      return result;
    } catch (const c10::Error& ex) {
      // We don't want to fail, but still log the warning
      // msg() returns the message without the stack trace
      TORCH_WARN("CUDA initialization: ", ex.msg());
      return 0;
    }
  }();
  return static_cast<DeviceIndex>(count);
}
```

The `static` local means the count is computed exactly once per process, on the first call, and is never refreshed afterwards.
I thought maybe this would work #122805 but it doesn't >:(
```
(/home/ezyang/local/b/pytorch-env) [[email protected] ~/local/b/pytorch (8743ac67)]$ python wu.py
8
8
Traceback (most recent call last):
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 306, in _lazy_init
    queued_call()
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 174, in _check_capability
    capability = get_device_capability(d)
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 430, in get_device_capability
    prop = get_device_properties(device)
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 448, in get_device_properties
    return _get_device_properties(device)  # type: ignore[name-defined]
RuntimeError: device=, num_gpus=

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/ezyang/b/pytorch/wu.py", line 6, in <module>
    print(torch.empty(10, device='cuda'))
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 312, in _lazy_init
    raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: device=, num_gpus=

CUDA call was originally invoked at:
  File "/data/users/ezyang/b/pytorch/wu.py", line 1, in <module>
    import torch
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/data/users/ezyang/b/pytorch/torch/__init__.py", line 1485, in <module>
    _C._initExtension(manager_path())
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 238, in <module>
    _lazy_call(_check_capability)
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 235, in _lazy_call
    _queued_calls.append((callable, traceback.format_stack()))

(/home/ezyang/local/b/pytorch-env) [[email protected] ~/local/b/pytorch (8743ac67)]$ cat wu.py
import torch
import os
print(torch.cuda.device_count())
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
print(torch.cuda.device_count())
print(torch.empty(10, device='cuda'))
```
Maybe someone can figure out what I did wrong lol.
Oh, it's because we are using NVML to query the device count 🤔
So calling device_count initializes NVML. Is that the root cause of the problem here? Should we stop using NVML to query the device count here?
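For context, a rough sketch of what an NVML-based count looks like, here using the third-party pynvml bindings purely for illustration (PyTorch talks to NVML through its own internal binding, not pynvml). Notably, NVML itself ignores CUDA_VISIBLE_DEVICES, so the mask has to be applied by hand, and that hand-applied value is exactly what can go stale:

```python
import os
import pynvml  # NVML bindings: pip install nvidia-ml-py


def nvml_device_count() -> int:
    """Count visible GPUs via NVML, without creating a CUDA context."""
    pynvml.nvmlInit()
    try:
        total = pynvml.nvmlDeviceGetCount()
    finally:
        pynvml.nvmlShutdown()
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        return total
    # Simplified masking: real parsing also handles GPU UUIDs and stops
    # at the first invalid entry.
    ids = [v.strip() for v in visible.split(",") if v.strip()]
    return sum(1 for v in ids if v.isdigit() and int(v) < total)


print(nvml_device_count())
```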
Hi @ezyang, I think the title of this ticket has become a bit misleading after all the investigation.
To summarize my understanding of the issue: when importing some torch submodules such as torch._dynamo, somewhat expectedly torch.cuda.device_count will be called (pytorch/torch/_dynamo/device_interface.py, line 198 at 37e3c8f) and the outcome cached (pytorch/torch/cuda/__init__.py, line 742 at 37e3c8f). So, any setenv of CUDA_VISIBLE_DEVICES after the import statement may make the actual device count inconsistent with the (cached) output of torch.cuda.device_count()
...
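Assuming the cached torch.cuda.device_count referenced above is a functools.lru_cache-style wrapper (an assumption; the exact mechanism is version-dependent), one unsupported stopgap on the user side is to clear the cache after changing the environment:

```python
import os
import torch

print(torch.cuda.device_count())  # first call populates the cache

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Hypothetical workaround: drop the stale cached value so the next call
# re-reads CUDA_VISIBLE_DEVICES. This only does anything if device_count is
# actually lru_cache-wrapped in this build, and only helps as long as CUDA
# itself has not been initialized yet.
if hasattr(torch.cuda.device_count, "cache_clear"):
    torch.cuda.device_count.cache_clear()

print(torch.cuda.device_count())
```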
Looking at the implementation (pytorch/torch/cuda/__init__.py, lines 748 to 749 at 37e3c8f), you're right -- in our use case torch.cuda.device_count calls _device_count_nvml, which doesn't seem to allocate any GPU memory. If it instead falls through to torch._C._cuda_getDeviceCount(), it indeed weakly initializes CUDA and allocates a small amount of GPU memory.
Agreed (if performance is not a concern here), thanks for looking into the issue. Also, this module-level torch.cuda.device_count call could be delayed as well: pytorch/torch/_dynamo/device_interface.py, line 198 at 37e3c8f. One possible shape of that change is sketched below.
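A minimal sketch of such a deferral (hypothetical class layout, not the actual device_interface.py code): move the query from import time into a lazily populated accessor, so importing the module no longer touches CUDA or NVML.

```python
from typing import Optional

import torch


class CudaInterface:
    # Instead of an eager import-time attribute like
    #     device_count = torch.cuda.device_count()
    # keep an empty slot and fill it on first use.
    _device_count: Optional[int] = None

    @classmethod
    def device_count(cls) -> int:
        # Importing this module now triggers no CUDA/NVML queries at all;
        # the count is read (and cached) only when first requested.
        if cls._device_count is None:
            cls._device_count = torch.cuda.device_count()
        return cls._device_count
```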