Comments (4)
Narrowed down the root cause to the `tree_map_only` used by `_init_state_dict`:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/planner_helpers.py#L297-L303
```python
def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    state_dict_assigned_storage = tree_map_only(
        torch.Tensor, lambda v: _init_meta_tensor(v), state_dict
    )
    # The in-place version of tree_map_only, tree_map_only_, doesn't seem to work,
    # so we need to temporarily update each element in the state dict with a meta tensor.
    for k in state_dict.keys():
        state_dict[k] = state_dict_assigned_storage[k]
```
Doing the `tree_map_only` out-of-place and then assigning it back this way only maintains references for the leaves, not for the branches, which presumably breaks assumptions made by the reading code. Using the in-place `tree_map_only_` seems to resolve the issue, but I'm not sure whether it breaks something else, as suggested by @wz337 in the comment:
```python
def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    tree_map_only_(torch.Tensor, _init_meta_tensor, state_dict)
```
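To illustrate the aliasing problem without torch, here is a minimal plain-Python sketch. `map_leaves` is a hypothetical stand-in for `tree_map_only`, not the real pytree API: out-of-place mapping rebuilds every container, so reassigning only the top-level keys leaves previously held references to inner branches pointing at stale objects.

```python
# Hypothetical out-of-place mapper (stand-in for tree_map_only):
# rebuilds every dict/list container, applying fn only to leaves.
def map_leaves(fn, tree):
    if isinstance(tree, dict):
        return {k: map_leaves(fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [map_leaves(fn, v) for v in tree]
    return fn(tree)

state = {"optimizer": {"param_groups": [{"lr": 0.1}]}}
optim_ref = state["optimizer"]  # caller-held reference to a branch

mapped = map_leaves(lambda v: v * 2, state)
for k in state.keys():          # same top-level reassignment as _init_state_dict
    state[k] = mapped[k]

print(state["optimizer"] is optim_ref)              # False: branch was rebuilt
print(optim_ref["param_groups"][0]["lr"])           # 0.1 -- stale original
print(state["optimizer"]["param_groups"][0]["lr"])  # 0.2 -- new object
```

The leaf values survive the round trip, but any code that kept a handle on an inner dict (as the optimizer state dict does below) keeps seeing the pre-load objects.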
from pytorch.
Did a git bisect and found that f518cf8 is the first bad commit.
@wz337
@Jackmin801 For now, you can use `dcp_state_dict["optimizer"]` to ensure you get the correctly loaded state_dict. @wz337 I think our test cases do not catch this use case, where we access the state_dict through the original reference.
@fegin It seems like `dcp_state_dict["optimizer"]` doesn't work either:
```python
import os

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict


def load_checkpoint(
    checkpoint_path: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
) -> dict:
    """Load the model and optimizer state from a checkpoint folder.

    Args:
        checkpoint_path: the path to the checkpoint folder
        model: the model to load
        optimizer: the optimizer to load
        scheduler: the scheduler to load
    """
    # 1. Load distributed states
    fs_storage_reader = dcp.FileSystemReader(checkpoint_path)
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    dcp_state_dict = {
        "model": model_state_dict,
        "optimizer": optimizer_state_dict,
    }
    dcp.load(dcp_state_dict, storage_reader=fs_storage_reader)
    assert (
        optimizer_state_dict["param_groups"][0]["lr"]
        == dcp_state_dict["optimizer"]["param_groups"][0]["lr"]
    )
    set_state_dict(
        model,
        optimizer,
        model_state_dict=dcp_state_dict["model"],
        optim_state_dict=dcp_state_dict["optimizer"],
    )
    # 2. Load global states
    global_state_dict = torch.load(os.path.join(checkpoint_path, "global_state_dict.pt"))
    scheduler.load_state_dict(global_state_dict["scheduler"])
```
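For contrast, here is a minimal in-place analogue in plain Python. `map_leaves_` is a hypothetical stand-in for `tree_map_only_`, not the real pytree API: it mutates the existing containers instead of rebuilding them, so an outstanding reference like `optimizer_state_dict` above would keep seeing the updated values.

```python
# Hypothetical in-place mapper (stand-in for tree_map_only_): mutates
# containers in place so every outstanding reference into the tree stays valid.
def map_leaves_(fn, tree):
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, list):
        items = enumerate(tree)
    else:
        return
    for key, value in list(items):
        if isinstance(value, (dict, list)):
            map_leaves_(fn, value)   # recurse into branches, keeping the container
        else:
            tree[key] = fn(value)    # replace only the leaf

state = {"optimizer": {"param_groups": [{"lr": 0.1}]}}
optim_ref = state["optimizer"]       # caller-held reference to a branch

map_leaves_(lambda v: v * 2, state)

print(state["optimizer"] is optim_ref)     # True: branch identity preserved
print(optim_ref["param_groups"][0]["lr"])  # 0.2 -- the old reference sees the update
```

This is the property the assert in `load_checkpoint` relies on: leaf updates must be visible through branch references taken before the load.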