Comments (4)

Jackmin801 commented on July 2, 2024

Narrowed down the root cause to the tree_map_only used by _init_state_dict.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/planner_helpers.py#L297-L303

def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    state_dict_assigned_storage = tree_map_only(
        torch.Tensor, lambda v: _init_meta_tensor(v), state_dict
    )
    # The in-place version of tree_map_only, tree_map_only_, doesn't seem to work.
    # So we need to temporarily update each element in the state dict with the meta tensor.
    for k in state_dict.keys():
        state_dict[k] = state_dict_assigned_storage[k]

Running tree_map_only out of place and then assigning the results back this way preserves references only for the leaves, not for the branches: the top-level dict object is reused, but every nested container is rebuilt as a new object. I guess this breaks some assumptions made by the reading code. Using the in-place tree_map_only_ seems to resolve the issue, but I'm not sure whether it breaks something else, as @wz337 suggested in the comment.
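
To make the reference behavior concrete, here is a minimal, PyTorch-free sketch. It assumes a simplified, hypothetical toy_tree_map_only over plain dicts and lists (not the actual torch.utils._pytree implementation) and a placeholder Tensor class standing in for torch.Tensor; the loader's write into the state dict is simulated by a direct assignment.

```python
from typing import Any, Callable


class Tensor:  # hypothetical placeholder for torch.Tensor
    pass


def toy_tree_map_only(ty: type, fn: Callable[[Any], Any], tree: Any) -> Any:
    """Out of place: rebuilds every dict/list container and applies fn to leaves of type ty."""
    if isinstance(tree, dict):
        return {k: toy_tree_map_only(ty, fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_tree_map_only(ty, fn, v) for v in tree]
    return fn(tree) if isinstance(tree, ty) else tree


# Caller-side layout similar to the issue: dcp_state_dict wraps optimizer_state_dict.
optimizer_state_dict = {"param_groups": [{"lr": 0.0}], "state": {0: Tensor()}}
dcp_state_dict = {"optimizer": optimizer_state_dict}

# Same pattern as _init_state_dict: map out of place, then rebind only the top-level keys.
# The identity fn mimics a non-meta tensor being returned unchanged.
mapped = toy_tree_map_only(Tensor, lambda t: t, dcp_state_dict)
for k in dcp_state_dict.keys():
    dcp_state_dict[k] = mapped[k]

# Simulated load: a non-tensor value is written into the (new) nested dict.
dcp_state_dict["optimizer"]["param_groups"][0]["lr"] = 1e-3

print(optimizer_state_dict["param_groups"][0]["lr"])  # 0.0: the caller's branch is stale
print(dcp_state_dict["optimizer"] is optimizer_state_dict)  # False: branch replaced
print(dcp_state_dict["optimizer"]["state"][0] is optimizer_state_dict["state"][0])  # True: leaf shared
```

The caller's optimizer_state_dict never sees the loaded value, while the tensor leaf is still shared, which matches the leaves-versus-branches observation above.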

def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    tree_map_only_(torch.Tensor, _init_meta_tensor, state_dict)
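
For comparison, a hypothetical in-place toy_tree_map_only_ that writes each mapped leaf back into the existing containers keeps every branch intact, so references held by the caller stay valid. This is an illustrative assumption, not a description of torch.utils._pytree.tree_map_only_; whether the real function stores the mapped values back into the tree (which replacing meta tensors via _init_meta_tensor would require) is worth checking against the PyTorch source before relying on this fix.

```python
from typing import Any, Callable


class Tensor:  # hypothetical placeholder for torch.Tensor
    pass


def toy_tree_map_only_(ty: type, fn: Callable[[Any], Any], tree: Any) -> Any:
    """In place: mutates the existing dict/list containers, replacing leaves of type ty."""
    if isinstance(tree, dict):
        for k, v in list(tree.items()):
            if isinstance(v, (dict, list)):
                toy_tree_map_only_(ty, fn, v)
            elif isinstance(v, ty):
                tree[k] = fn(v)  # write the mapped leaf back into the same dict
    elif isinstance(tree, list):
        for i, v in enumerate(tree):
            if isinstance(v, (dict, list)):
                toy_tree_map_only_(ty, fn, v)
            elif isinstance(v, ty):
                tree[i] = fn(v)  # write the mapped leaf back into the same list
    return tree


optimizer_state_dict = {"param_groups": [{"lr": 0.0}], "state": {0: Tensor()}}
dcp_state_dict = {"optimizer": optimizer_state_dict}
old_leaf = optimizer_state_dict["state"][0]

# Replace each tensor leaf with a fresh object, as materializing a meta tensor might.
toy_tree_map_only_(Tensor, lambda t: Tensor(), dcp_state_dict)
dcp_state_dict["optimizer"]["param_groups"][0]["lr"] = 1e-3  # simulated load

print(dcp_state_dict["optimizer"] is optimizer_state_dict)  # True: branch preserved
print(optimizer_state_dict["param_groups"][0]["lr"])  # 0.001: visible to the caller
print(optimizer_state_dict["state"][0] is old_leaf)  # False: leaf replaced in place
```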

cc @fegin @LucasLLC

Jackmin801 commented on July 2, 2024

Ran a git bisect and found that f518cf8 is the first bad commit.
@wz337

fegin commented on July 2, 2024

@Jackmin801 For now, you can use dcp_state_dict["optimizer"] to make sure you get the correctly loaded state_dict. @wz337 I think our test cases do not catch this use case, where the state_dict is accessed through the original reference.
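
A regression test for this access pattern could check that a value loaded through a wrapper dict is also visible via the caller's original reference to the nested dict. The sketch below is a minimal, hypothetical harness: original_reference_sees_load, rebinding_load, and in_place_load are illustrative names, and the two loaders only mimic the out-of-place and in-place behaviors discussed in this thread rather than calling dcp.load.

```python
import copy
from typing import Any, Callable, Dict

StateDict = Dict[str, Any]


def original_reference_sees_load(load_fn: Callable[[StateDict], None]) -> bool:
    """Return True if a value loaded by load_fn is visible through the caller's
    original reference to the nested optimizer dict."""
    optimizer_state_dict = {"param_groups": [{"lr": 0.0}]}
    wrapper = {"optimizer": optimizer_state_dict}
    load_fn(wrapper)
    return optimizer_state_dict["param_groups"][0]["lr"] == 1e-3


def rebinding_load(sd: StateDict) -> None:
    # Hypothetical loader: deep-copies the tree, rebinds the top-level keys, then loads.
    copied = copy.deepcopy(sd)
    for k in sd:
        sd[k] = copied[k]
    sd["optimizer"]["param_groups"][0]["lr"] = 1e-3


def in_place_load(sd: StateDict) -> None:
    # Hypothetical loader: writes directly into the existing nested containers.
    sd["optimizer"]["param_groups"][0]["lr"] = 1e-3


print(original_reference_sees_load(rebinding_load))  # False
print(original_reference_sees_load(in_place_load))  # True
```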

Jackmin801 commented on July 2, 2024

@fegin It seems like dcp_state_dict["optimizer"] doesn't work either:

import os

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict


def load_checkpoint(checkpoint_path: str, model: torch.nn.Module, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LambdaLR) -> None:
    """Load the model, optimizer, and scheduler state from a checkpoint folder.

    Args:
        checkpoint_path: the path to the checkpoint folder
        model: the model to load
        optimizer: the optimizer to load
        scheduler: the scheduler to load
    """
    # 1. Load distributed states
    fs_storage_reader = dcp.FileSystemReader(checkpoint_path)

    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    dcp_state_dict = {
        "model": model_state_dict,
        "optimizer": optimizer_state_dict,
    }
    dcp.load(dcp_state_dict, storage_reader=fs_storage_reader)
    # Check that the original reference and the entry in dcp_state_dict agree after loading
    assert optimizer_state_dict["param_groups"][0]["lr"] == dcp_state_dict["optimizer"]["param_groups"][0]["lr"]
    set_state_dict(model, optimizer, model_state_dict=dcp_state_dict["model"], optim_state_dict=dcp_state_dict["optimizer"])

    # 2. Load global states
    global_state_dict = torch.load(os.path.join(checkpoint_path, "global_state_dict.pt"))
    scheduler.load_state_dict(global_state_dict["scheduler"])
