Comments (7)
Pretty sure this is caused by this line: https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py#L552.
To confirm, @ad8e, you can likely repro this with just:
optim = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True)
...
optim.step()
The real solution is to allow foreach_div to support Scalar as the first argument, but I'm not sure how hard that is, cc @crcrpar. It feels like we should be able to just add an overload. Regarding priority, I'm not sure this is high priority. How likely is this use case? Is there a real use case for having lr be 0?
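For reference, a self-contained version of that repro might look like the sketch below (this assumes a CUDA device, since capturable=True expects the params and step counters on an accelerator; the model and input are placeholders):
import torch

# Minimal sketch of the suspected repro: AdamW with capturable=True and lr=0.0,
# no DTensor and no torch.compile involved. Assumes a CUDA device is available.
model = torch.nn.Linear(8, 8, device="cuda")
optim = torch.optim.AdamW(model.parameters(), lr=0.0, capturable=True)

loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()
optim.step()

# If the capturable path divides by the zero step size, the parameters go NaN here.
print(any(p.isnan().any().item() for p in model.parameters()))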
Is this actually related to DTensor, or is this more about torch.compile + optimizer? Based on the analysis above, I think if we just use a normal torch.Tensor with torch.compile and set lr=0.0, we should still repro the issue?
The underlying bug is not in DTensor; it's in the optimizer. It's just that DTensor exposes this code path in the optimizer. A normal torch.Tensor with torch.compile and lr=0.0 doesn't hit it; the capturable argument that Jane mentioned is the key.
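For intuition, here is a rough sketch of the arithmetic (not the exact PyTorch source): in the capturable path, lr stays a tensor and is folded into the denominator so the update can be issued as a single addcdiv_, which means lr=0.0 turns into a division by zero:
import torch

# Rough shape of the capturable AdamW update, with placeholder values; the real
# code at the linked adamw.py line is more involved, but the division is the point.
lr = torch.tensor(0.0)
bias_correction1 = torch.tensor(0.1)
bias_correction2_sqrt = torch.tensor(0.3)
exp_avg = torch.zeros(4)
exp_avg_sq = torch.zeros(4)
eps = 1e-8
param = torch.randn(4)

step_size = lr / bias_correction1  # 0.0 when lr == 0.0
denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size)).add_(eps / step_size)
param.addcdiv_(exp_avg, denom)  # 0/0 and x/0 terms turn param into NaN
print(param)
In the non-capturable path, as far as I can tell, the step size is instead applied as the value= multiplier of addcdiv_, so lr=0.0 just produces a zero update rather than NaN.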
@bdhirsh is this related to the torchtitan NaN loss you were talking about?
@ad8e Does the NaN repro with a single GPU?
DTensor doesn't work when I change the TP mesh size from 2 to 1; I receive:
[rank1]: Traceback (most recent call last):
[rank1]: File "/clusterstorage/workspace/kevin/nandtensor.py", line 99, in <module>
[rank1]: model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/tensor/parallel/api.py", line 82, in parallelize_module
[rank1]: random._rng_tracker._manual_seed(device_mesh, base_seed=1234)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/random.py", line 345, in _manual_seed
[rank1]: tensor_parallel_rank = tp_mesh.get_local_rank()
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 502, in get_local_rank
[rank1]: mesh_dim_group = not_none(self.get_group(mesh_dim))
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/device_mesh.py", line 411, in get_group
[rank1]: _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])
[rank1]: IndexError: list index out of range
which means the process group isn't being created when the mesh dim size is 1, so I can't test whether the NaN would appear with a single GPU.
If I remove the DTensor, like so:
# model_tp = parallelize_module(model, tp_mesh, parallelize_plan=layer_plan)
model_tp = model
...
# gas_loss = gas_loss.full_tensor() # commented out
Then no NaNs appear. So the NaN only appears with DTensor.
It's not high priority for me because DTensor TP is currently useless due to low performance, so I don't use it anywhere. If DTensor actually mattered (above the 70B scale, or if it finally gets comm/comp overlap working), then this would affect linear decay/warmup, where lr=0.0 is common at the endpoints but avoidable. Another use case would be re-baking the AdamW second moment, which is needed when resuming from a checkpoint saved without optimizer states (useful for saving disk space); that can be done with a very low LR instead of 0.0.
If anyone else cared about DTensor, they would be able to spot the NaN issue and work around it in both cases, since it is not a silent failure.
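To make the warmup endpoint concrete, a plain linear warmup really does hand the optimizer lr=0.0 on the first step; a sketch (the schedule shape and the 1e-8 floor are just illustrative, not an official fix):
import torch

# Sketch: a linear warmup via LambdaLR starts at lr = 0.0 on step 0, which is
# exactly the value the capturable AdamW path trips over.
model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

warmup_steps = 1000
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda step: min(step / warmup_steps, 1.0))
print(sched.get_last_lr())  # [0.0] -> the first optimizer step would run with lr == 0.0

# Workaround sketch: floor the factor at a tiny positive value so the schedule
# never hands the optimizer exactly 0.0.
safe_lambda = lambda step: max(min(step / warmup_steps, 1.0), 1e-8)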
I tried Jane's test case by taking the original DTensor TP=2 example and making these modifications:
opt = AdamW(...
capturable=True, # this is new
)
...
# opt.step = torch.compile(opt.step) # this is removed
The NaNs appear. So her diagnosis is correct.