Comments (3)

ad8e commented on June 20, 2024

> Could this issue be resolved by changing the torch.save file to { "model": ModelWrapper(model) },

It could. But torch.save files are generally not in this format ({"model": ...}), and there's no benefit to giving them that format, whose only purpose is to cooperate with how DCP calls the state-dict functions. With the current API, it makes more sense for users to keep torch.save files in their current format and then use _save_state_dict to introduce the dictionary when converting to DCP.

> Alternatively, we could do something like:

I can't tell whether this would work in the distributed setting. Some parallelization techniques (like FSDP1 and DTensor) shard the state dict, and without public documentation we can only experiment to find out which options work together.

I did attempt a few related approaches, such as sd = get_model_state_dict(model) followed by dcp.load(sd, ...), but was not successful. There were a few moving parts, such as the model's device, and maybe I just guessed the wrong options.

> _load_state_dict_from_keys

This seems useful if I want to write a script that deletes optimizer state after training finishes. I agree, this should be in a different issue.

from pytorch.

LucasLLC commented on June 20, 2024

Hi @ad8e, thanks for raising these issues.

Could you clarify the initial issue with torch_save_to_dcp? Could this issue be resolved by changing the torch.save file to { "model": ModelWrapper(model) }, or do we not have access to the generated file?

Alternatively, we could do something like:

sd = model.state_dict()
dcp.load(sd, ...)

> DCP has a few gaps in being able to do what I want, such as saving model and optimizer to different directories (so I can delete the optimizer states after training to save disk space), or saving a distributed model's state dict without this dict-in-dict behavior.
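Outside of DCP, the disk-space half of this can at least be sketched with plain torch.save (all file names here are hypothetical):

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.AdamW(model.parameters())

work = tempfile.mkdtemp()
full = os.path.join(work, "train_ckpt.pt")
slim = os.path.join(work, "model_only.pt")

# During training: one file holding both model and optimizer state.
torch.save({"model": model.state_dict(), "optim": opt.state_dict()}, full)

# After training: rewrite the checkpoint with the optimizer state dropped.
ckpt = torch.load(full)
torch.save({"model": ckpt["model"]}, slim)
os.remove(full)  # reclaim the disk space
```

Saving to separate directories up front would avoid the rewrite step entirely, which is the gap being described.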

This sounds pretty reasonable, but I believe it deserves a separate issue if you're up for opening one :) The general recommendation at this point would be to use [torch.distributed.checkpoint._load_state_dict_from_keys](https://github.com/pytorch/pytorch/blob/d15920a7d05ce881a0b45903e5b98f932ddd6439/torch/distributed/checkpoint/state_dict_loader.py#L230) to do something like:

sd = _load_state_dict_from_keys("model", ...)
dcp.save(sd, checkpoint_id=path)


LucasLLC commented on June 20, 2024

Good points @ad8e, we're taking some of these notes to the drawing board. Appreciate the feedback.

