
Comments (19)

ptrblck commented on May 10, 2024

@ezyang

We should not be initializing CUDA upon import.

Yes, totally agree, but I also don't see a cuInit here (unless I'm using a wrong PyTorch build).

We've recently seen a real CUDA init call while importing PyTorch in #116276
and worked on a unit test making sure this won't happen again.
CC @malfet, @Aidyn-A, @atalman, as the test landed in #117043.

Running the script from this issue in a debugger also does not show any cuInit calls:

lldb --batch -o "b cuInit" -o "run" -- python3 -c "import torch;print(torch.cuda.device_count() > 0); import torch._higher_order_ops.map"

while adding ; torch.randn(1).cuda() indeed hits the breakpoint.

The script tries to set CUDA_VISIBLE_DEVICES after torch was imported, which is too late.
This env variable should be set before importing any CUDA-enabled libraries in the script, or exported in your terminal.
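
For example, a minimal sketch of the ordering being described (the device index "0" is just an illustration):

import os
# Restrict device visibility before torch (or any other CUDA-enabled library)
# is imported, so the CUDA runtime/NVML only ever sees the selected device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
print(torch.cuda.device_count())  # reflects the restricted set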

Nevertheless, we might still want to delay the device count call to keep import torch lightweight.

wyli commented on May 10, 2024

sure, I can have a look soon @ezyang

ezyang commented on May 10, 2024

Oh. So if we get rid of the LRU cache when CUDA is not initialized, that should fix this once and for all. I'll send a patch.
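
Roughly the shape of that idea, as a sketch rather than the actual patch (_raw_device_count here is a hypothetical stand-in for the real uncached NVML/driver query):

import torch
_cached_device_count = None  # only populated once CUDA has really been initialized
def device_count() -> int:
    global _cached_device_count
    if _cached_device_count is not None:
        return _cached_device_count
    count = _raw_device_count()  # hypothetical uncached query
    # Caching before CUDA init would burn in the current CUDA_VISIBLE_DEVICES,
    # so only remember the result once initialization has actually happened.
    if torch.cuda.is_initialized():
        _cached_device_count = count
    return count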

ezyang commented on May 10, 2024

Yeah, your patch is still good, I still want it in

bdhirsh commented on May 10, 2024

Marking with the pt2 oncall since most of the relevant files are PT2-related.

ezyang commented on May 10, 2024

We should not be initializing CUDA upon import.

Nic-Ma commented on May 10, 2024

Hi @ezyang ,

Thanks for your quick response here.
Could you please connect us with the relevant engineer to help fix it?
We are updating to the latest PyTorch for our products https://github.com/Project-MONAI/MONAI and https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/containers/monai-toolkit,
and this is a blocker for the release.

Thanks in advance.

Chillee commented on May 10, 2024

Could it be this? #121864

wyli commented on May 10, 2024

Looking into this further, I think it's related to the device_count caching issue #95073.

More recently, this module-level device_count call seems to make the issue more noticeable:

for i in range(torch.cuda.device_count()):
(cc @voznesenskym #117386)

ezyang commented on May 10, 2024

@wyli Voz isn't at Meta anymore; do you think you could send a patch to make this lazy or something?

wyli commented on May 10, 2024

Put a possible fix here in case it's useful: #122795.

ezyang commented on May 10, 2024

I think there is a notion of initialization which is weaker than cuInit: some CUDA API calls, like torch.cuda.device_count, burn in the device count based on the setting of the envvar at that point. So we are probably testing the cuInit case, but NOT this case (which is more pernicious, because AFAICT the only way you can notice it has happened is that setting CUDA_VISIBLE_DEVICES stops working).

ezyang commented on May 10, 2024

Actually, reading over our code, this may be an entirely PyTorch-inflicted problem:

DeviceIndex device_count() noexcept {
  // initialize number of devices only once
  static int count = []() {
    try {
      auto result = device_count_impl(/*fail_if_no_driver=*/false);
      TORCH_INTERNAL_ASSERT(
          result <= std::numeric_limits<DeviceIndex>::max(),
          "Too many CUDA devices, DeviceIndex overflowed");
      return result;
    } catch (const c10::Error& ex) {
      // We don't want to fail, but still log the warning
      // msg() returns the message without the stack trace
      TORCH_WARN("CUDA initialization: ", ex.msg());
      return 0;
    }     
  }();    
  return static_cast<DeviceIndex>(count);
}

ezyang commented on May 10, 2024

I thought maybe this would work #122805 but it doesn't >:(

(/home/ezyang/local/b/pytorch-env) [[email protected] ~/local/b/pytorch (8743ac67)]$ python wu.py
8
8
Traceback (most recent call last):
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 306, in _lazy_init
    queued_call()
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 174, in _check_capability
    capability = get_device_capability(d)
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 430, in get_device_capability
    prop = get_device_properties(device)
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 448, in get_device_properties
    return _get_device_properties(device)  # type: ignore[name-defined]
RuntimeError: device=, num_gpus=

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/ezyang/b/pytorch/wu.py", line 6, in <module>
    print(torch.empty(10, device='cuda'))
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 312, in _lazy_init
    raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error: device=, num_gpus=

CUDA call was originally invoked at:

  File "/data/users/ezyang/b/pytorch/wu.py", line 1, in <module>
    import torch
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/data/users/ezyang/b/pytorch/torch/__init__.py", line 1485, in <module>
    _C._initExtension(manager_path())
  File "<frozen importlib._bootstrap>", line 1027, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1006, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 238, in <module>
    _lazy_call(_check_capability)
  File "/data/users/ezyang/b/pytorch/torch/cuda/__init__.py", line 235, in _lazy_call
    _queued_calls.append((callable, traceback.format_stack()))
(/home/ezyang/local/b/pytorch-env) [[email protected] ~/local/b/pytorch (8743ac67)]$ cat wu.py
import torch
import os
print(torch.cuda.device_count())
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
print(torch.cuda.device_count())
print(torch.empty(10, device='cuda'))

Maybe someone can figure out what I did wrong lol.

ezyang commented on May 10, 2024

Oh, it's because we are using NVML to query the device count 🤔

ezyang commented on May 10, 2024

So calling device_count initializes NVML. Is that the root cause of the problem here? Should we stop using NVML to query the device count here?
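
For reference, NVML itself can report the device count without creating a CUDA context; here is a minimal sketch using the pynvml bindings (an illustration only -- PyTorch binds NVML in its own code rather than going through pynvml):

import pynvml
pynvml.nvmlInit()
try:
    # Counts devices known to the driver; unlike a CUDA runtime call, this does
    # not create a CUDA context or allocate GPU memory on any device.
    print(pynvml.nvmlDeviceGetCount())
finally:
    pynvml.nvmlShutdown()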

wyli commented on May 10, 2024

Hi @ezyang, I think the title of this ticket has become a bit misleading after all the investigation.

To summarize my understanding of the issue: when importing some torch submodules, such as import torch._dynamo, torch.cuda.device_count is (somewhat expectedly) called

for i in range(torch.cuda.device_count()):

and the result is cached via
@lru_cache(maxsize=1)

So, setting CUDA_VISIBLE_DEVICES after the import statement may make the actual device count inconsistent with the (cached) output of torch.cuda.device_count().
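
A standalone sketch of that caching behavior (not PyTorch's actual code), showing why later environment changes are ignored:

import os
from functools import lru_cache
@lru_cache(maxsize=1)
def visible_devices():
    # The environment is read only on the first call; the result then sticks.
    return os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(visible_devices())  # "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print(visible_devices())  # still "0,1" -- the cached value wins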

wyli commented on May 10, 2024

Looking at the implementation:

nvml_count = -1 if torch.version.hip else _device_count_nvml()
return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count

you're right -- in our use case torch.cuda.device_count calls _device_count_nvml, which doesn't seem to allocate any GPU memory. If it falls back to torch._C._cuda_getDeviceCount(), it does weakly initialize CUDA and allocate a small amount of GPU memory.

wyli commented on May 10, 2024

Agreed (if performance is not a concern here), thanks for looking into the issue.

Also, this module-level torch.cuda.device_count call could be delayed as well:

for i in range(torch.cuda.device_count()):
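
One possible shape for that deferral, sketched with a hypothetical helper (the real loop body lives in the PyTorch source):

import torch
def _all_cuda_devices():
    # Enumerated on first use instead of at import time, so merely importing
    # the module no longer calls torch.cuda.device_count().
    return [torch.device("cuda", i) for i in range(torch.cuda.device_count())]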
