Comments (9)
@pytorchbot label "oncall: distributed"
@mvpatel2000 Is the model_state_dict that is fed to set_model_state_dict loaded from DCP? If not, then you are loading a state_dict that was saved with ShardedTensor into an FSDP model that is initialized with DTensor, without any conversion. You will need some conversion first, which is generally done by DCP. If model_state_dict is loaded from DCP, then we need to debug why the conversion does not happen.
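For context, the load flow I would expect looks roughly like this (a minimal sketch; model is assumed to be the live FSDP-wrapped module with a sharded state_dict_type configured, and the checkpoint path is hypothetical):

import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import FileSystemReader

storage_reader = FileSystemReader("/path/to/checkpoint")  # hypothetical path

# The in-memory state_dict from the live model is the template: DCP reads the
# checkpoint data and reshards it into whatever tensor types the template
# holds, which is how a ShardedTensor checkpoint ends up in DTensor weights.
state_dict = {"model": model.state_dict()}
dist_cp.load_state_dict(state_dict=state_dict, storage_reader=storage_reader)
model.load_state_dict(state_dict["model"])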
@fegin yes, we are loading from DCP. Here is the snippet:
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=storage_reader,
    planner=None,
    process_group=process_group,
)
Can you check whether state_dict contains ShardedTensor before it is passed into load_state_dict?
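One quick way to check (a minimal sketch; the import is a private torch 2.x path, and the nesting under state_dict['state']['model'] matches the dump printed later in the thread):

from torch.distributed._shard.sharded_tensor import ShardedTensor

# Report any ShardedTensor values in the model sub-dict before loading.
for name, value in state_dict["state"]["model"].items():
    if isinstance(value, ShardedTensor):
        print(f"{name} is a ShardedTensor")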
@fegin Yes, it has ShardedTensors in it.
Code:
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=storage_reader,
    planner=state.fsdp_config['load_planner'],
    no_dist=(not dist.is_initialized()),
)
print('Dist CP loaded state_dict', state_dict)
Loaded state dict {'state': {'model': OrderedDict([('module.0.weight', ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[16, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[16, 0], shard_sizes=[16, 32], placement=rank:1/cuda:1)], size=torch.Size([32, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False)))), ('module.0.bias', ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[16], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[16], shard_sizes=[16], placement=rank:1/cuda:1)], size=torch.Size([32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False)))), ('module.2.weight', ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[4, 32], placement=rank:1/cuda:1)], size=torch.Size([8, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))))]), 'optimizers': {'Adam': {'state': {'module.0.weight': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[16, 32], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16, 0], shard_sizes=[16, 32], placement=rank:1/cpu)], size=torch.Size([32, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[16, 32], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16, 0], shard_sizes=[16, 32], placement=rank:1/cpu)], size=torch.Size([32, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(2.)}, 'module.0.bias': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[16], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16], shard_sizes=[16], placement=rank:1/cpu)], size=torch.Size([32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[16], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16], shard_sizes=[16], placement=rank:1/cpu)], size=torch.Size([32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), ...
No, I meant to print state_dict before load_state_dict. But I think the answer will be the same, given that load_state_dict loads in place. So the issue is where you get state_dict from. It should come from the model you are going to load into:
state_dict = model.state_dict()
dist_cp.load_state_dict(state_dict, ...)
model.load_state_dict(state_dict)
So if state_dict has ShardedTensor, that means the FSDP model is configured to use ShardedTensor, not DTensor, and that should not cause errors. If you want to load into DTensor, then state_dict should also contain DTensor.
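For reference, which tensor type shows up in the sharded state_dict follows from how FSDP was initialized. As far as I know (torch 2.2+ behavior; MyModel is a hypothetical module), initializing FSDP with a device_mesh makes SHARDED_STATE_DICT produce DTensor instead of ShardedTensor:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# With a device_mesh, FSDP shards parameters as DTensor, so sharded
# state dicts contain DTensor; without it they contain ShardedTensor.
model = FSDP(MyModel().cuda(), device_mesh=mesh)  # MyModel is hypothetical

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = model.state_dict()  # values are DTensor here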
@fegin apologies, I have verified it was actually going down another code path and not using DCP. I believe we use this path for legacy sharded checkpoints (non-elastic). I will dig further on my end -- I assume the right solution is to reroute these to DCP.
@mvpatel2000 Thanks for the update. I'll close the issue for now. Feel free to reopen the issue if you still see the problem.