> Could this issue be resolved by changing the torch.save file to `{ "model": ModelWrapper(model) }`?
It could. But torch.save files are generally not in this format (`{"model": ...}`), and there is little reason for them to be: that wrapping exists only to cooperate with how DCP calls state dict functions. With the current API, it makes more sense for users to keep torch.save files in their current format and then use `_save_state_dict` to introduce the dictionary when converting to DCP.
> Alternatively, we could do something like: `sd = model.state_dict()` followed by `dcp.load(sd, ...)`
I cannot tell if this would work in the distributed setting. Some of the parallelization techniques (like FSDP1 and DTensor) shard the state dict, and without public documentation, we can only experiment to find which options work together.
I did attempt a few related methods, such as using `sd = get_model_state_dict(model)` followed by `dcp.load(sd, ...)`, but was not successful. There were a few moving parts, such as the model device, and maybe I just guessed the wrong options.
> `_load_state_dict_from_keys`

This seems useful if I want to write a script to delete the optimizer states after training finishes. I agree, this should be in a different issue.
from pytorch.
Hi @ad8e, thanks for raising these issues.
Could you clarify the initial issue with `torch_save_to_dcp`? Could this issue be resolved by changing the torch.save file to `{ "model": ModelWrapper(model) }`, or do we not have access to the generated file?
Alternatively, we could do something like:
```python
import torch.distributed.checkpoint as dcp
sd = model.state_dict()
dcp.load(sd, ...)
```
> DCP has a few gaps in being able to do what I want, such as saving model and optimizer to different directories (so I can delete the optimizer states after training to save disk space), or saving a distributed model's state dict without this dict-in-dict behavior.
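To see where the dict-in-dict behavior comes from: DCP flattens nested state dicts into dot-joined keys before writing, so the same tensors are stored under different names depending on whether they were wrapped. The helper below is a hypothetical, simplified imitation of that naming (it ignores lists and other details of DCP's real traversal) and uses strings in place of tensors.

```python
from typing import Any, Dict, Tuple


def flatten_like_dcp(state_dict: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[str, Any]:
    """Simplified imitation of DCP's key flattening: nested dict paths joined with '.'."""
    flat: Dict[str, Any] = {}
    for key, value in state_dict.items():
        path = prefix + (str(key),)
        if isinstance(value, dict):
            # Recurse into nested dicts, e.g. {"model": {...}} -> "model.<key>".
            flat.update(flatten_like_dcp(value, path))
        else:
            flat[".".join(path)] = value
    return flat


model_sd = {"weight": "W", "bias": "b"}  # stand-ins for tensors

# A flat torch.save-style dict keeps its keys as-is...
print(sorted(flatten_like_dcp(model_sd)))  # ['bias', 'weight']
# ...while the {"model": ...} wrapper prefixes every key.
print(sorted(flatten_like_dcp({"model": model_sd})))  # ['model.bias', 'model.weight']
```

In these terms, a checkpoint written from the flat dict holds `weight`, while a loader that asks for `{"model": ...}` looks for `model.weight`, so the two do not line up without an explicit rewrap.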
This sounds pretty reasonable, but I believe it deserves a separate issue if you're up for opening one :) The general recommendation at this point would be to use [torch.distributed.checkpoint._load_state_dict_from_keys](https://github.com/pytorch/pytorch/blob/d15920a7d05ce881a0b45903e5b98f932ddd6439/torch/distributed/checkpoint/state_dict_loader.py#L230) to do something like:
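```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
sd = _load_state_dict_from_keys("model", ....)
dcp.save(sd, checkpoint_id=path)
```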
Good points @ad8e, we're taking some of these notes back to the drawing board. Appreciate the feedback!