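Hello, I want to continue training the 1.3B version of GPT-3, but loading the checkpoint fails with Missing key(s) in state_dict.
The training code comes from the text-continuation fine-tuning example. The only change is that the last line now resumes training from a checkpoint: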
trainer.train(os.path.join(tmp_dir, "epoch_12.pth"))
Here is the error log:
Traceback (most recent call last):
  File "train_test.py", line 63, in <module>
    trainer.train(os.path.join(tmp_dir, "epoch_12.pth"))
  File "/root/miniconda3/lib/python3.8/site-packages/modelscope/trainers/trainer.py", line 495, in train
    self.train_loop(self.train_dataloader)
  File "/root/miniconda3/lib/python3.8/site-packages/modelscope/trainers/trainer.py", line 872, in train_loop
    self.invoke_hook(TrainerStages.before_run)
  File "/root/miniconda3/lib/python3.8/site-packages/modelscope/trainers/trainer.py", line 1034, in invoke_hook
    getattr(hook, fn_name)(self)
  File "/root/miniconda3/lib/python3.8/site-packages/modelscope/trainers/hooks/checkpoint_hook.py", line 82, in before_run
    meta = self.load_checkpoint(self.checkpoint_file, trainer,
  File "/root/miniconda3/lib/python3.8/site-packages/modelscope/trainers/hooks/checkpoint_hook.py", line 120, in load_checkpoint
    meta = load_checkpoint(
  File "/root/miniconda3/lib/python3.8/site-packages/modelscope/utils/checkpoint.py", line 136, in load_checkpoint
    model.load_state_dict(state_dict)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GPT3ForTextGeneration:
	Missing key(s) in state_dict: "model.dist_model.language_model.embedding.word_embeddings.weight", ......
......
	Unexpected key(s) in state_dict: "model.language_model.embedding.word_embeddings.weight", ......
......
How should I handle this?
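The missing and unexpected keys in the log look like the same parameters stored under two different prefixes. A minimal pure-Python sketch for checking whether a single prefix rename explains every mismatch; the helper names are hypothetical, and the key lists are illustrative (the second key is made up), shaped like the ones in the traceback rather than taken from the real model.

```python
def diff_state_dict_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected) key lists, the two sets strict loading reports."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)      # expected by the model, absent from the file
    unexpected = sorted(ckpt_keys - model_keys)   # present in the file, unknown to the model
    return missing, unexpected


def explained_by_prefix(missing, unexpected, old_prefix, new_prefix):
    """True if renaming old_prefix -> new_prefix in the unexpected keys yields exactly the missing keys."""
    renamed = {
        new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key
        for key in unexpected
    }
    return renamed == set(missing)


# Illustrative key lists (the position_embeddings key is hypothetical).
model_keys = [
    "model.dist_model.language_model.embedding.word_embeddings.weight",
    "model.dist_model.language_model.embedding.position_embeddings.weight",
]
ckpt_keys = [
    "model.language_model.embedding.word_embeddings.weight",
    "model.language_model.embedding.position_embeddings.weight",
]

missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
print(explained_by_prefix(missing, unexpected,
                          "model.language_model.", "model.dist_model.language_model."))  # True
```

If the check returns True, renaming the keys in the saved checkpoint should be enough; if not, the two models differ in more than naming.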
from modelscope.
Is the bert-base you are using the Hugging Face (hf) one? Could you share your configuration.json, and also the changes you made in your cfg_modify_fn? Thanks.
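I modified the structure_bert backbone model.
What I changed is the optimizer's type and lr, and I set model_id to the corresponding output_path, i.e. the directory that contains configuration.json.
The changes in cfg_modify_fn are roughly: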
cfg.train.optimizer.type = "SDG"
cfg.train.optimizer.lr = 2e-5
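Is there an error message? I suspect the main cause is that "SDG" is not a registered optimizer type (the standard optimizer name is "SGD").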
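Because the optimizer is presumably looked up by the name in cfg.train.optimizer.type, as the reply above suggests, a misspelling like "SDG" cannot be resolved. A minimal sketch of a cfg_modify_fn that sets the intended "SGD" type and fails early with a suggestion on typos. Assumptions: the callback receives the config and returns it, as in the ModelScope fine-tuning examples; KNOWN_OPTIMIZERS is a hypothetical, illustrative subset (each name exists as a torch.optim class); SimpleNamespace stands in for the real config object so the snippet runs without ModelScope.

```python
import difflib
from types import SimpleNamespace

# Hypothetical, illustrative subset of optimizer names (each exists in torch.optim).
KNOWN_OPTIMIZERS = ["SGD", "Adam", "AdamW"]


def check_optimizer_type(name):
    """Raise a readable error, with a close-match suggestion, for unknown optimizer names."""
    if name not in KNOWN_OPTIMIZERS:
        hint = difflib.get_close_matches(name, KNOWN_OPTIMIZERS, n=1)
        suggestion = f" Did you mean {hint[0]!r}?" if hint else ""
        raise ValueError(f"Unknown optimizer type {name!r}.{suggestion}")


def cfg_modify_fn(cfg):
    """Override the optimizer settings, validate them, and return the config."""
    cfg.train.optimizer.type = "SGD"  # "SGD", not "SDG"
    cfg.train.optimizer.lr = 2e-5
    check_optimizer_type(cfg.train.optimizer.type)
    return cfg


# Stand-in for the trainer config, which supports attribute access.
cfg = SimpleNamespace(train=SimpleNamespace(optimizer=SimpleNamespace(type="AdamW", lr=1e-4)))
cfg = cfg_modify_fn(cfg)
print(cfg.train.optimizer.type, cfg.train.optimizer.lr)  # SGD 2e-05
```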
Is this the correct way to resume training? When I call trainer.train(os.path.join('./workdirs/damoyolo_s', 'epoch_1_ckpt.pth')), I get TypeError: train() takes 1 positional argument but 2 were given
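Which ModelScope version are you using?
It also depends on which model you are training. For NLP models this usage is correct; for CV models, some CV-specific trainers' train() may not support passing a checkpoint.
Which task are you fine-tuning? Please post it so we can take a look.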
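The damoyolo trainer's train() indeed does not support passing a checkpoint; we will make this consistent across trainers in a later update.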
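The TypeError above means the trainer's train() has no parameter for a checkpoint path. A minimal sketch, using only the standard inspect module, for detecting this before calling it; resume_training and both trainer classes are hypothetical stand-ins for the NLP-style and CV-style trainers discussed here, not ModelScope classes.

```python
import inspect

_POSITIONAL = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    inspect.Parameter.VAR_POSITIONAL,
)


def accepts_checkpoint(train_fn):
    """True if the bound method train_fn takes at least one positional argument besides self."""
    params = inspect.signature(train_fn).parameters.values()
    return any(p.kind in _POSITIONAL for p in params)


def resume_training(trainer, checkpoint_path):
    """Resume from checkpoint_path, or fail with a clearer message than the bare TypeError."""
    if not accepts_checkpoint(trainer.train):
        raise TypeError(
            f"{type(trainer).__name__}.train() does not accept a checkpoint path; "
            "resuming is not supported by this trainer")
    return trainer.train(checkpoint_path)


# Hypothetical stand-ins for the two trainer styles.
class NlpStyleTrainer:
    def train(self, checkpoint_path=None):
        return f"resumed from {checkpoint_path}"


class CvStyleTrainer:
    def train(self):
        return "trained from scratch"


print(resume_training(NlpStyleTrainer(), "epoch_12.pth"))  # resumed from epoch_12.pth
```

Checking the signature of the bound method means self is already excluded, so an empty parameter list is exactly the "takes 1 positional argument" case.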
I used a rather brute-force workaround: renaming every model.language_model key to model.dist_model.language_model. Training now resumes.
Here is the simple conversion code for reference; my ModelScope version is 1.2.1.
import torch

checkpoint = torch.load("somewhere/epoch_12.pth", map_location="cpu")
state_dict = checkpoint["state_dict"]

# Map each new key name to the old key it replaces
new_key_dict = {}
for old_key in state_dict.keys():
    if old_key.startswith("model.language_model."):
        # Replace only the first occurrence, i.e. the prefix
        new_key = old_key.replace("model.language_model.", "model.dist_model.language_model.", 1)
        new_key_dict[new_key] = old_key

# Move each tensor under its new key
for new_key, old_key in new_key_dict.items():
    state_dict.update({new_key: state_dict.pop(old_key)})

# Epochs saved by 1.2.1 must be incremented by 1, otherwise one extra epoch is trained
checkpoint["meta"]["epoch"] += 1
checkpoint["meta"]["inner_iter"] = 0
torch.save(checkpoint, "somewhere/epoch_12_format.pth")
I wonder whether there is a more elegant official solution.
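The same conversion can be written as a pure function that leaves the input untouched and preserves key order, which makes it easy to test before overwriting a checkpoint file. A minimal sketch: rename_prefix and convert_checkpoint are hypothetical names, the checkpoint is assumed to be a dict with "state_dict" and "meta" entries as in the script above, and the epoch bump reflects the 1.2.1-specific behavior reported in this thread. The demo uses a toy checkpoint with made-up values so it runs without torch.

```python
from collections import OrderedDict

OLD_PREFIX = "model.language_model."
NEW_PREFIX = "model.dist_model.language_model."


def rename_prefix(state_dict, old_prefix, new_prefix):
    """Return a new OrderedDict where keys starting with old_prefix get new_prefix instead."""
    return OrderedDict(
        (new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key, value)
        for key, value in state_dict.items()
    )


def convert_checkpoint(checkpoint, bump_epoch=True):
    """Return a converted shallow copy of checkpoint; the input dict is not modified."""
    converted = dict(checkpoint)
    converted["state_dict"] = rename_prefix(checkpoint["state_dict"], OLD_PREFIX, NEW_PREFIX)
    if bump_epoch:  # workaround reported for checkpoints saved by ModelScope 1.2.1
        meta = dict(checkpoint["meta"])
        meta["epoch"] += 1
        meta["inner_iter"] = 0
        converted["meta"] = meta
    return converted


# With real files (requires torch):
#   ckpt = torch.load("somewhere/epoch_12.pth", map_location="cpu")
#   torch.save(convert_checkpoint(ckpt), "somewhere/epoch_12_format.pth")
demo = {
    "state_dict": {"model.language_model.embedding.word_embeddings.weight": 0.0},
    "meta": {"epoch": 12, "inner_iter": 5},
}
out = convert_checkpoint(demo)
print(list(out["state_dict"]), out["meta"])
```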
This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.