
Comments (9)

mvpatel2000 commented on July 3, 2024

@pytorchbot label "oncall: distributed"

from pytorch.

awgu commented on July 3, 2024

cc: @fegin @wz337 @LucasLLC

fegin commented on July 3, 2024

@mvpatel2000 Is the model_state_dict that is fed to set_model_state_dict loaded from DCP? If not, then you are loading a state_dict that was saved with ShardedTensor into an FSDP model that is initialized with DTensor, without any conversion. You will need some conversion first, which is generally done in DCP.

If model_state_dict is loaded from DCP, then we need to debug why the conversion does not happen.

mvpatel2000 commented on July 3, 2024

If model_state_dict is loaded from DCP, then we need to debug why the conversion does not happen.

@fegin Yes, we are loading from DCP. Here is a snippet:

dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=storage_reader,
    planner=None,
    process_group=process_group,
)

fegin commented on July 3, 2024

Can you check whether state_dict contains ShardedTensor before it is passed into load_state_dict?
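One way to run that check is to walk the nested state_dict and report the type of every leaf (a minimal sketch; `find_tensor_types` and `has_sharded_tensor` are hypothetical helpers written for this debugging step, not DCP APIs, and they match on the class name so they work on any nested dict):

```python
def find_tensor_types(obj, path=""):
    """Recursively collect (path, type_name) for every leaf in a nested state_dict."""
    if isinstance(obj, dict):  # also covers OrderedDict
        out = []
        for key, value in obj.items():
            child = f"{path}.{key}" if path else str(key)
            out.extend(find_tensor_types(value, child))
        return out
    return [(path, type(obj).__name__)]

def has_sharded_tensor(state_dict):
    """True if any leaf in the state_dict is a ShardedTensor (matched by class name)."""
    return any(name == "ShardedTensor" for _, name in find_tensor_types(state_dict))
```

Calling `find_tensor_types(state_dict)` right before `dist_cp.load_state_dict` would show, key by key, whether the leaves are ShardedTensor, DTensor, or plain tensors.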

mvpatel2000 commented on July 3, 2024

@fegin Yes it has ShardedTensors in it.

Code:

dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=storage_reader,
    planner=state.fsdp_config['load_planner'],
    no_dist=(not dist.is_initialized()),
)

print('Dist CP loaded state_dict', state_dict)

Loaded state dict {'state': {'model': OrderedDict([
('module.0.weight', ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[16, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[16, 0], shard_sizes=[16, 32], placement=rank:1/cuda:1)], size=torch.Size([32, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False)))),
('module.0.bias', ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[16], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[16], shard_sizes=[16], placement=rank:1/cuda:1)], size=torch.Size([32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False)))),
('module.2.weight', ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[4, 32], placement=rank:1/cuda:1)], size=torch.Size([8, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))))]),
'optimizers': {'Adam': {'state': {
'module.0.weight': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[16, 32], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16, 0], shard_sizes=[16, 32], placement=rank:1/cpu)], size=torch.Size([32, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[16, 32], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16, 0], shard_sizes=[16, 32], placement=rank:1/cpu)], size=torch.Size([32, 32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(2.)},
'module.0.bias': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[16], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16], shard_sizes=[16], placement=rank:1/cpu)], size=torch.Size([32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[16], placement=rank:0/cpu), ShardMetadata(shard_offsets=[16], shard_sizes=[16], placement=rank:1/cpu)], size=torch.Size([32]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), ...

fegin commented on July 3, 2024

No, I meant to print state_dict before it is passed into load_state_dict. But I think the answer will be the same, given that load_state_dict modifies state_dict in place. So the issue is where you get state_dict: it should come from the model you are going to load into.

state_dict = model.state_dict()
dist_cp.load_state_dict(state_dict, ...)
model.load_state_dict(state_dict)

So if state_dict contains ShardedTensor, that means the FSDP model is configured to use ShardedTensor, not DTensor, and in that case loading should not cause errors. If you want to load into DTensor, then state_dict should also contain DTensor.
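The three-line pattern above can be sketched as a generic in-place contract (a minimal sketch; `load_checkpoint_in_place` and the `DummyModel`-style usage are hypothetical, and `load_fn` stands in for a call like `dist_cp.load_state_dict(state_dict, storage_reader=...)`):

```python
def load_checkpoint_in_place(model, load_fn):
    """Sketch of the DCP-style in-place load contract:
    1. materialize the state_dict from the model you will load into
       (its leaf types — ShardedTensor vs. DTensor — follow the model's config),
    2. let the checkpoint loader fill that dict in place,
    3. load the filled dict back into the model.
    """
    state_dict = model.state_dict()
    load_fn(state_dict)  # e.g. dist_cp.load_state_dict(state_dict, storage_reader=...)
    model.load_state_dict(state_dict)
```

This is why the leaf types in state_dict mirror the model's configuration: the dict is produced by the destination model, not by the checkpoint.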

mvpatel2000 commented on July 3, 2024

@fegin Apologies, I have verified that it was actually going down another code path and not using DCP. I believe we use this path for legacy (non-elastic) sharded checkpoints. I will dig further on my end -- I assume the right fix is to reroute these to DCP.

fegin commented on July 3, 2024

@mvpatel2000 Thanks for the update. I'll close the issue for now. Feel free to reopen the issue if you still see the problem.
