lyhue1991 / torchkeras
Pytorch❤️ Keras 😋😋
License: Apache License 2.0
When I use a multi-input model (such as a Transformer model), I get errors like TypeError: forward() missing 2 required positional arguments: 'tgt' and 'tgt_mask'.
I looked at the torchkeras source code and think the problem is the way torchkeras.torchkeras.Model.forward() passes its arguments on to the wrapped network.
So I suggest modifying torchkeras.torchkeras.Model.forward() like this:
Change:
torchkeras/torchkeras/torchkeras.py
Lines 22 to 26 in 9d66849
To:
def forward(self, *x):  # Note this line: accept any number of inputs
    if self.net:
        return self.net.forward(*x)  # Note this line: unpack them into the net
    else:
        raise NotImplementedError
The places where torchkeras.torchkeras.Model.forward() is called should also be modified accordingly, for example:
Change:
torchkeras/torchkeras/torchkeras.py
Line 59 in 9d66849
To:
predictions = self.forward(*features)
And so on for the other call sites.
If my suggestion is wrong, please ignore it.
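To see why the proposed change helps, here is a minimal pure-Python sketch (no torch needed) of the difference between forward(x) and forward(*x). `TransformerLike` and `Wrapper` are hypothetical stand-ins, not torchkeras classes; the point is only how Python binds the arguments.

```python
class TransformerLike:
    """Hypothetical stand-in for a net whose forward needs three inputs."""

    def forward(self, src, tgt, tgt_mask):
        return (src, tgt, tgt_mask)


class Wrapper:
    """Hypothetical stand-in for a Model that delegates to self.net."""

    def __init__(self, net):
        self.net = net

    def forward_single(self, x):
        # Old style: the whole tuple is passed as ONE positional argument.
        return self.net.forward(x)

    def forward(self, *x):
        # Proposed style: the tuple is unpacked into separate arguments.
        return self.net.forward(*x)


features = ("src", "tgt", "mask")
model = Wrapper(TransformerLike())

try:
    model.forward_single(features)
except TypeError as err:
    print("forward(x) failed:", err)

print(model.forward(*features))  # → ('src', 'tgt', 'mask')
```

Passing the tuple as a single argument leaves `tgt` and `tgt_mask` unfilled, which reproduces the TypeError in the report; unpacking with `*` binds one element per parameter.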
Code: the official example
https://www.kaggle.com/code/lyhue1991/torchkeras-ddp-tpu-examples/notebook
Runtime environment:
the Run button in PyCharm
When I run the training script from your examples/ChatGLM2——transformers.ipynb on multiple GPUs, I get the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!
(when checking argument for argument target in method wrapper_CUDA_nll_loss_forward)
in <module>:1 │
│ │
│ ❱ 1 keras_model.fit(train_data = dl_train, │
│ 2 │ │ │ │ val_data = dl_val, │
│ 3 │ │ │ │ epochs=100,patience=5, │
│ 4 │ │ │ │ monitor='val_loss',mode='min', │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torchkeras/kerasmodel.py:204 in │
│ fit │
│ │
│ 201 │ │ │ │
│ 202 │ │ │ train_epoch_runner = self.EpochRunner(train_step_runner,should_quiet) │
│ 203 │ │ │ train_metrics = {'epoch':epoch} │
│ ❱ 204 │ │ │ train_metrics.update(train_epoch_runner(train_dataloader)) │
│ 205 │ │ │ │
│ 206 │ │ │ for name, metric in train_metrics.items(): │
│ 207 │ │ │ │ self.history[name] = self.history.get(name, []) + [metric] │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torchkeras/kerasmodel.py:77 in │
│ __call__ │
│ │
│ 74 │ │ │
│ 75 │ │ for step, batch in loop: │
│ 76 │ │ │ with self.accelerator.accumulate(self.net): │
│ ❱ 77 │ │ │ │ step_losses,step_metrics = self.steprunner(batch) │
│ 78 │ │ │ │ step_log = dict(step_losses,**step_metrics) │
│ 79 │ │ │ │ for k,v in step_losses.items(): │
│ 80 │ │ │ │ │ epoch_losses[k] = epoch_losses.get(k,0.0)+v │
│ │
│ in __call__:20 │
│ │
│ 17 │ │ │
│ 18 │ │ #loss │
│ 19 │ │ with self.accelerator.autocast(): │
│ ❱ 20 │ │ │ loss = self.net(input_ids=batch["input_ids"],labels=batch["labels"]).loss │
│ 21 │ │ │
│ 22 │ │ #backward() │
│ 23 │ │ if self.optimizer is not None and self.stage=="train": │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torch/nn/modules/module.py:1501 │
│ in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/accelerate/utils/operations.py:521 in forward │
│ │
│ 518 │ model_forward = ConvertOutputsToFp32(model_forward) │
│ 519 │ │
│ 520 │ def forward(*args, **kwargs): │
│ ❱ 521 │ │ return model_forward(*args, **kwargs) │
│ 522 │ │
│ 523 │ # To act like a decorator so that it can be popped when doing `extract_model_from_pa │
│ 524 │ forward.wrapped = model_forward │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/accelerate/utils/operations.py:509 in __call__ │
│ │
│ 506 │ │ update_wrapper(self, model_forward) │
│ 507 │ │
│ 508 │ def __call__(self, *args, **kwargs): │
│ ❱ 509 │ │ return convert_to_fp32(self.model_forward(*args, **kwargs)) │
│ 510 │ │
│ 511 │ def __getstate__(self): │
│ 512 │ │ raise pickle.PicklingError( │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torch/amp/autocast_mode.py:14 in │
│ decorate_autocast │
│ │
│ 11 │ @functools.wraps(func) │
│ 12 │ def decorate_autocast(*args, **kwargs): │
│ 13 │ │ with autocast_instance: │
│ ❱ 14 │ │ │ return func(*args, **kwargs) │
│ 15 │ decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in │
│ 16 │ return decorate_autocast │
│ 17 │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/peft/peft_model.py:686 in forward │
│ │
│ 683 │ ): │
│ 684 │ │ peft_config = self.active_peft_config │
│ 685 │ │ if not isinstance(peft_config, PromptLearningConfig): │
│ ❱ 686 │ │ │ return self.base_model( │
│ 687 │ │ │ │ input_ids=input_ids, │
│ 688 │ │ │ │ attention_mask=attention_mask, │
│ 689 │ │ │ │ inputs_embeds=inputs_embeds, │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torch/nn/modules/module.py:1501 │
│ in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/accelerate/hooks.py:165 in │
│ new_forward │
│ │
│ 162 │ │ │ with torch.no_grad(): │
│ 163 │ │ │ │ output = old_forward(*args, **kwargs) │
│ 164 │ │ else: │
│ ❱ 165 │ │ │ output = old_forward(*args, **kwargs) │
│ 166 │ │ return module._hf_hook.post_forward(module, output) │
│ 167 │ │
│ 168 │ module.forward = new_forward │
│ │
│ /u01/liuys/.cache/huggingface/modules/transformers_modules/THUDM/chatglm2-6b/c57e892806dfe383cd5 │
│ caf09719628788fe96379/modeling_chatglm.py:957 in forward │
│ │
│ 954 │ │ │ shift_labels = labels[..., 1:].contiguous() │
│ 955 │ │ │ # Flatten the tokens │
│ 956 │ │ │ loss_fct = CrossEntropyLoss(ignore_index=-100) │
│ ❱ 957 │ │ │ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.v │
│ 958 │ │ │ │
│ 959 │ │ │ lm_logits = lm_logits.to(hidden_states.dtype) │
│ 960 │ │ │ loss = loss.to(hidden_states.dtype) │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torch/nn/modules/module.py:1501 │
│ in _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torch/nn/modules/loss.py:1174 in │
│ forward │
│ │
│ 1171 │ │ self.label_smoothing = label_smoothing │
│ 1172 │ │
│ 1173 │ def forward(self, input: Tensor, target: Tensor) -> Tensor: │
│ ❱ 1174 │ │ return F.cross_entropy(input, target, weight=self.weight, │
│ 1175 │ │ │ │ │ │ │ ignore_index=self.ignore_index, reduction=self.reduction, │
│ 1176 │ │ │ │ │ │ │ label_smoothing=self.label_smoothing) │
│ 1177 │
│ │
│ /u01/liuys/anaconda3/envs/chatglm/lib/python3.11/site-packages/torch/nn/functional.py:3029 in │
│ cross_entropy │
│ │
│ 3026 │ │ ) │
│ 3027 │ if size_average is not None or reduce is not None: │
│ 3028 │ │ reduction = _Reduction.legacy_get_string(size_average, reduce) │
│ ❱ 3029 │ return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(re │
│ 3030 │
│ 3031 │
│ 3032 def binary_cross_entropy(
Description:
Broken link: the "Kaggle free GPU usage guide" video (《Kaggle免费GPU使用攻略》) https://www.bilibili.com/video/BV1oa411u7uR/%E3%80%8B
Has anyone else run into the same thing? It looks like the loss function is the problem.
ValueError: Can't find 'adapter_config.json' at '/data/jupyter/root/torchkeras/examples/single_chatglm2'
Code: the official example
https://www.kaggle.com/code/lyhue1991/torchkeras-ddp-tpu-examples/notebook
Runtime environment:
the Run button in PyCharm
Error message:
How do I use the GPU?
Where is the dataset "dfner_13k.pkl"?
self.accelerator.backward(loss)
RuntimeError: Expected is_sm80 to be true, but got false.
Neither accelerate=0.21 nor accelerate=0.20 works.
Hi, I have a problem when I use the torch.summary function to print the structure of a multi-input model: I don't know how to pass the "input_shape" parameter. I tried a tuple like ((10, ), (10, 1)) and a list like [(10, ), (10, )]. Can you help me? Thanks in advance.
Needs the latest version.
TypeError: ChatGLMForConditionalGeneration.stream_chat() takes from 3 to 9 positional arguments but 11 were given
After a few more training runs, it works fine!
TypeError Traceback (most recent call last)
Cell In[13], line 3
1 from transformers import AutoModel,AutoTokenizer
2 model_name = "/mnt/workspace/chatglm2-6b-AdaLoRA/chatglm2-6b-梦中情炉"
----> 3 tokenizer = AutoTokenizer.from_pretrained(
4 model_name, trust_remote_code=True)
5 model = AutoModel.from_pretrained(model_name,
6 trust_remote_code=True).half().cuda()
File /home/pai/lib/python3.9/site-packages/transformers/models/auto/tokenization_auto.py:679, in AutoTokenizer.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
675 if tokenizer_class is None:
676 raise ValueError(
677 f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported."
678 )
--> 679 return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
681 # Otherwise we have to be creative.
682 # if model is an encoder decoder, the encoder tokenizer class is used by default
683 if isinstance(config, EncoderDecoderConfig):
File /home/pai/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1804, in PreTrainedTokenizerBase.from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
1801 else:
1802 logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
-> 1804 return cls._from_pretrained(
1805 resolved_vocab_files,
1806 pretrained_model_name_or_path,
1807 init_configuration,
1808 *init_inputs,
1809 use_auth_token=use_auth_token,
1810 cache_dir=cache_dir,
1811 local_files_only=local_files_only,
1812 _commit_hash=commit_hash,
1813 **kwargs,
1814 )
File /home/pai/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1958, in PreTrainedTokenizerBase._from_pretrained(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, use_auth_token, cache_dir, local_files_only, _commit_hash, *init_inputs, **kwargs)
1956 # Instantiate tokenizer.
1957 try:
-> 1958 tokenizer = cls(*init_inputs, **init_kwargs)
1959 except OSError:
1960 raise OSError(
1961 "Unable to load vocabulary from file. "
1962 "Please check that the provided vocabulary is accessible and not corrupted."
1963 )
File ~/.cache/huggingface/modules/transformers_modules/chatglm2-6b-梦中情炉/tokenization_chatglm.py:69, in ChatGLMTokenizer.__init__(self, vocab_file, padding_side, **kwargs)
68 def __init__(self, vocab_file, padding_side="left", **kwargs):
---> 69 super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=False, **kwargs)
70 self.name = "GLMTokenizer"
72 self.vocab_file = vocab_file
TypeError: transformers.tokenization_utils.PreTrainedTokenizer.__init__() got multiple values for keyword argument 'clean_up_tokenization_spaces'
================================================================================2023-07-06 14:03:03
Epoch 1 / 100
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /torchkeras/train.py:230 in │
│ │
│ 227 │ │ optimizer=torch.optim.AdamW(model.parameters(),lr=2e-6)) │
│ 228 ckpt_path = './waimai_chatglm4' │
│ 229 │
│ ❱ 230 keras_model.fit(train_data = dl_train, │
│ 231 │ │ │ │ val_data = dl_val, │
│ 232 │ │ │ │ epochs=100,patience=5, │
│ 233 │ │ │ │ monitor='val_loss',mode='min', │
│ │
│ /torchkeras/torchkeras/kerasmodel.py:204 in fit │
│ │
│ 201 │ │ │ │ │ loss_fn = self.loss_fn, │
│ 202 │ │ │ │ │ accelerator = self.accelerator, │
│ 203 │ │ │ │ │ stage="train", │
│ ❱ 204 │ │ │ │ │ metrics_dict=deepcopy(self.metrics_dict), │
│ 205 │ │ │ │ │ optimizer = self.optimizer if epoch>0 else None, │
│ 206 │ │ │ │ │ lr_scheduler = self.lr_scheduler if epoch>0 else None │
│ 207 │ │ │ ) │
│ │
│ /usr/lib/python3.8/copy.py:172 in deepcopy │
│ │
│ 169 │ │ │ │ if isinstance(rv, str): │
│ 170 │ │ │ │ │ y = x │
│ 171 │ │ │ │ else: │
│ ❱ 172 │ │ │ │ │ y = _reconstruct(x, memo, *rv) │
│ 173 │ │
│ 174 │ # If is its own copy, don't memoize. │
│ 175 │ if y is not x: │
│ │
│ /usr/lib/python3.8/copy.py:270 in _reconstruct │
│ │
│ 267 │ │
│ 268 │ if state is not None: │
│ 269 │ │ if deep: │
│ ❱ 270 │ │ │ state = deepcopy(state, memo) │
│ 271 │ │ if hasattr(y, '__setstate__'): │
│ 272 │ │ │ y.__setstate__(state) │
│ 273 │ │ else: │
│ │
│ /usr/lib/python3.8/copy.py:146 in deepcopy │
│ │
│ 143 │ │
│ 144 │ copier = _deepcopy_dispatch.get(cls) │
│ 145 │ if copier is not None: │
│ ❱ 146 │ │ y = copier(x, memo) │
│ 147 │ else: │
│ 148 │ │ if issubclass(cls, type): │
│ 149 │ │ │ y = _deepcopy_atomic(x, memo) │
│ │
│ /usr/lib/python3.8/copy.py:230 in _deepcopy_dict │
│ │
│ 227 │ y = {} │
│ 228 │ memo[id(x)] = y │
│ 229 │ for key, value in x.items(): │
│ ❱ 230 │ │ y[deepcopy(key, memo)] = deepcopy(value, memo) │
│ 231 │ return y │
│ 232 d[dict] = _deepcopy_dict │
│ 233 if PyStringMap is not None: │
│ │
│ /usr/lib/python3.8/copy.py:161 in deepcopy │
│ │
│ 158 │ │ │ │ else: │
│ 159 │ │ │ │ │ reductor = getattr(x, "__reduce_ex__", None) │
│ 160 │ │ │ │ │ if reductor is not None: │
│ ❱ 161 │ │ │ │ │ │ rv = reductor(4) │
│ 162 │ │ │ │ │ else: │
│ 163 │ │ │ │ │ │ reductor = getattr(x, "__reduce__", None) │
│ 164 │ │ │ │ │ │ if reductor: │
│ │
│ /usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py:498 in __getstate__ │
│ │
│ 495 │ │ return convert_to_fp32(self.model_forward(*args, **kwargs)) │
│ 496 │ │
│ 497 │ def __getstate__(self): │
│ ❱ 498 │ │ raise pickle.PicklingError( │
│ 499 │ │ │ "Cannot pickle a prepared model with automatic mixed precision, please unwra │
│ 500 │ │ ) │
│ 501 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
PicklingError: Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with Accelerator.unwrap_model(model)
before pickling it.
I get this error as soon as training starts. Could you help me fix it?
When I use model = nn.DataParallel(model, device_ids=model_device_ids), I get this error: torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'compile'.
Can you fix this problem? Thank you very much!
AttributeError: 'ChatGLMTokenizer' object has no attribute 'build_prompt'
Is multi-GPU training supported?
torchkeras/torchkeras/kerasmodel.py
Line 166 in 125bdc3
I set device='cuda:0' in model.compile(), but model.fit() then fails with "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu!". I don't know why. When I don't set device, it runs correctly on the CPU, but I want to run the code on the GPU, so I set the "device" parameter. Why does it fail?...
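A common cause of this error is that the model and the batch end up on different devices. Below is a minimal sketch of moving a whole (possibly nested) batch to one device before calling the model. `move_to` is a hypothetical helper, not a torchkeras API; it is duck-typed, so it works on anything with a `.to(device)` method such as `torch.Tensor`, and a small `FakeTensor` stands in for real tensors so the sketch runs without torch.

```python
def move_to(obj, device):
    """Recursively move tensor-like objects inside obj to `device`."""
    if hasattr(obj, "to"):
        # torch.Tensor exposes .to(device); so does the FakeTensor below.
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to(value, device) for value in obj]
    if isinstance(obj, tuple):
        return tuple(move_to(value, device) for value in obj)
    return obj  # ints, strings, None, ... are returned unchanged


class FakeTensor:
    """Minimal stand-in for torch.Tensor so the sketch runs without torch."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


batch = {"features": FakeTensor(), "labels": [FakeTensor(), 0]}
moved = move_to(batch, "cuda:0")
print(moved["features"].device, moved["labels"][0].device)  # → cuda:0 cuda:0
```

With PyTorch one would typically take the target device from the model itself, for example `device = next(net.parameters()).device`, and move each batch there. Whether a given torchkeras version already does this internally is not shown in the report.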
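With version 3.8.0, this always appears when I run import torchkeras, and importing torchkeras.LightModel directly also fails. I checked, and LightModel is defined in lightmodel.py, so why does the import fail?
After heavily modifying it, I find it isn't supported?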
torchkeras/torchkeras/kerasmodel.py
Line 52 in 125bdc3
In torchkeras 3.9.2, running from torchkeras import LightModel raises "cannot import name 'LightModel' from 'torchkeras'". This worked in earlier versions. Has LightModel been removed in the latest version?
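For import errors like the two above, a quick first check is which version is actually installed and whether it exports the name at all. A minimal standard-library sketch (needs Python 3.8+ for `importlib.metadata`); `check_export` is a hypothetical helper, and it assumes the distribution name matches the import name, as it does for torchkeras.

```python
import importlib
from importlib import metadata


def check_export(package, name):
    """Return (installed version or None, whether `name` is an attribute of it)."""
    try:
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        version = None  # not installed as a distribution (or a stdlib module)
    try:
        module = importlib.import_module(package)
    except ImportError:
        return version, False
    return version, hasattr(module, name)


# Example: is LightModel exported by the installed torchkeras?
print(check_export("torchkeras", "LightModel"))
```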
torchkeras reports: No module named 'accelerate'
After installing accelerate with pip,
it reports: AttributeError: module 'signal' has no attribute 'SIGKILL'
My environment: Windows, Python 3.9.13, torchkeras 3.2.5.
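On Windows, Python's `signal` module does not define `SIGKILL` (the Python documentation lists it as Unix-only), which matches the AttributeError above. The report does not show which package references it, so this is only a sketch of the usual portable pattern, falling back to `SIGTERM` where `SIGKILL` is missing; `KILL_SIGNAL` is a hypothetical name.

```python
import signal

# SIGKILL exists only on Unix; fall back to SIGTERM elsewhere (e.g. Windows).
# The result can then be used as, e.g., os.kill(pid, KILL_SIGNAL).
KILL_SIGNAL = getattr(signal, "SIGKILL", signal.SIGTERM)

print(signal.Signals(KILL_SIGNAL).name)
```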
I don't know which version of peft introduced AdaLoraConfig; I'm using 0.3.0.
ImportError: cannot import name 'AdaLoraConfig' from 'peft'
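Hello, my custom loss needs extra arguments, but StepRunner calls self.loss_fn(preds, labels) with only predictions and labels. Apart from modifying the source code, is there a more elegant way?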
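One way to avoid editing StepRunner, assuming the extra arguments are fixed (or changed only from outside the step), is to bind them in advance so the loss keeps the `(preds, labels)` shape that StepRunner calls. `weighted_loss`, `alpha` and `reduction` are hypothetical; plain Python lists stand in for tensors here, and the same binding works for a torch loss.

```python
from functools import partial


def weighted_loss(preds, labels, alpha=1.0, reduction="mean"):
    """Hypothetical loss with extra parameters beyond (preds, labels)."""
    errors = [alpha * (p - y) ** 2 for p, y in zip(preds, labels)]
    return sum(errors) / len(errors) if reduction == "mean" else sum(errors)


# Option 1: functools.partial fixes the extra keyword arguments up front.
loss_fn = partial(weighted_loss, alpha=0.5, reduction="sum")


# Option 2: a small callable class, handy if the extra state should be
# changed between epochs (e.g. from a callback).
class WeightedLoss:
    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, preds, labels):
        return weighted_loss(preds, labels, alpha=self.alpha)


print(loss_fn([1.0, 2.0], [0.0, 0.0]))        # → 2.5
print(WeightedLoss(alpha=2.0)([1.0], [3.0]))  # → 8.0
```

If the extra inputs differ per batch (for example, per-sample weights coming from the dataloader), binding does not help, and customizing the step logic is likely needed.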
Python 3.7.6
torchkeras 3.8.8
Is this a version problem?
Running
ds_train = ds_train_raw.map(
    preprocess,
    batched=True,
    num_proc=4,
    remove_columns=ds_train_raw.column_names
)
ds_val = ds_val_raw.map(
    preprocess,
    batched=True,
    num_proc=4,
    remove_columns=ds_val_raw.column_names
)
raises:
AttributeError: 'ChatGLMTokenizer' object has no attribute 'build_prompt'
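This AttributeError suggests the locally cached chatglm2-6b tokenizer code predates `build_prompt`; updating the model files is the cleaner fix. As a stop-gap, a standalone function can build the prompt instead. The template below is assumed to mirror ChatGLM2's multi-round format (问 = question, 答 = answer); verify it against the `tokenization_chatglm.py` of the checkpoint you use. `build_prompt` and `get_prompt` here are hypothetical helpers.

```python
def build_prompt(query, history=None):
    """Assumed ChatGLM2-style multi-round prompt; check against your checkpoint."""
    history = history or []
    prompt = ""
    for i, (old_query, response) in enumerate(history):
        prompt += "[Round {}]\n\n问：{}\n\n答：{}\n\n".format(i + 1, old_query, response)
    prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
    return prompt


def get_prompt(tokenizer, query, history=None):
    # Prefer the tokenizer's own method when the installed code has it.
    if hasattr(tokenizer, "build_prompt"):
        return tokenizer.build_prompt(query, history=history)
    return build_prompt(query, history)


print(build_prompt("hello", [("hi", "Hi! How can I help?")]))
```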
from model import LeNet
import torchkeras
import torchmetrics
from torchvision import datasets
import torch.nn as nn
import torch
import torchvision.transforms as transforms
import torchvision

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4
net = LeNet()
# Note: CIFAR-10 has 10 classes, but BCEWithLogitsLoss and
# Accuracy(task='binary') both assume a binary target.
model = torchkeras.KerasModel(net,
                              loss_fn=nn.BCEWithLogitsLoss(),
                              optimizer=torch.optim.Adam(net.parameters(), lr=1e-4),
                              metrics_dict={"acc": torchmetrics.Accuracy(task='binary')}
                              )

trainset = torchvision.datasets.CIFAR10(root='../../../../Dataset', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../../../../Dataset', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

dfhistory = model.fit(train_data=trainloader,
                      val_data=testloader,
                      epochs=20,
                      patience=3,
                      ckpt_path='checkpoint.pt',
                      monitor="val_acc",
                      mode="max",
                      plot=True,
                      )
# Registering a Jupyter magic command makes it easy to test ChatGLM in Jupyter
from torchkeras.chat import ChatGLM
chatglm = ChatGLM(model, tokenizer)
register magic %%chatglm sucessed ...
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[5], line 3
1 # Registering a Jupyter magic command makes it easy to test ChatGLM in Jupyter
2 from torchkeras.chat import ChatGLM
----> 3 chatglm = ChatGLM(model, tokenizer)
File ~/anaconda3/envs/zdw/lib/python3.10/site-packages/torchkeras/chat/chatglm.py:27, in ChatGLM.__init__(self, model, tokenizer, stream, max_chat_rounds, history, max_length, num_beams, do_sample, top_p, temperature, logits_processor)
24 print('register magic %%chatglm failed ...')
25 print(err)
---> 27 response = self('你好')
28 if not self.stream:
29 print(response)
File ~/anaconda3/envs/zdw/lib/python3.10/site-packages/torchkeras/chat/chatglm.py:50, in ChatGLM.__call__(self, query)
43 return response
45 result = self.model.stream_chat(self.tokenizer,
46 query,self.history,None,self.max_length,
47 self.do_sample,self.top_p,self.temperature,
48 self.logits_processor,None)
---> 50 for response,history in result:
51 print(response)
52 clear_output(wait=True)
File ~/anaconda3/envs/zdw/lib/python3.10/site-packages/torch/utils/_contextlib.py:26, in _wrap_generator.<locals>.generator_context(*args, **kwargs)
24 @functools.wraps(func)
25 def generator_context(*args, **kwargs):
---> 26 gen = func(*args, **kwargs)
28 # Generators are suspended and unsuspended at `yield`, hence we
29 # make sure the grad mode is properly set every time the execution
30 # flow returns into the wrapped generator and restored when it
31 # returns through our `yield` to our caller (see PR #49017).
32 try:
33 # Issuing `None` to a generator fires it up
TypeError: ChatGLMForConditionalGeneration.stream_chat() takes from 3 to 9 positional arguments but 11 were given
What is going wrong here?
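The traceback shows `torchkeras/chat/chatglm.py` passing ten arguments positionally to `stream_chat`, while the installed model code accepts at most nine counting `self`, so the positions no longer line up. A minimal sketch of calling by keyword instead, keeping only the keywords the installed signature explicitly declares. `call_with_known_kwargs` is a hypothetical helper, and `OldModel`/`NewModel` are stand-ins (the first has a parameter count matching the error message), not the real ChatGLM code.

```python
import inspect


def call_with_known_kwargs(fn, *args, **kwargs):
    """Call fn, passing only the keyword arguments its signature names."""
    params = inspect.signature(fn).parameters
    named = {
        name
        for name, p in params.items()
        if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    }
    return fn(*args, **{k: v for k, v in kwargs.items() if k in named})


class OldModel:
    # Stand-in: self + 8 parameters = "from 3 to 9 positional arguments".
    def stream_chat(self, tokenizer, query, history=None, max_length=8192,
                    do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, **kwargs):
        return ("old", max_length, top_p)


class NewModel:
    # Stand-in for a variant that also accepts past_key_values.
    def stream_chat(self, tokenizer, query, history=None, past_key_values=None,
                    max_length=8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False,
                    **kwargs):
        return ("new", max_length, top_p)


settings = dict(history=[], past_key_values=None, max_length=2048,
                do_sample=True, top_p=0.7, temperature=0.95,
                logits_processor=None)

for m in (OldModel(), NewModel()):
    print(call_with_known_kwargs(m.stream_chat, "tokenizer", "hello", **settings))
```

For a real model the call would look like `call_with_known_kwargs(model.stream_chat, tokenizer, query, **settings)`; `inspect.signature` follows `functools.wraps`, so a wrapping decorator like the one in the traceback does not hide the parameters.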
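First, thank you for open-sourcing such a meaningful project!
I found that torchkeras does not support learning-rate decay well. First, lr_scheduler.step() is not passed the metrics argument. Second, after using an lr_scheduler, the learning rate has clearly decreased, but the lr shown in the progress bar stays the same... It shouldn't be hard to change; I hope the author can push a fix.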
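Both points in this report can be sketched without touching the rest of the training loop: pass the monitored metric only to schedulers whose `step()` takes one (`torch.optim.lr_scheduler.ReduceLROnPlateau.step(metrics)` does), and read the displayed learning rate from the optimizer's `param_groups`, which schedulers update. The helper names are hypothetical, not torchkeras APIs, and `FakePlateau` is a simplified stand-in (no patience or threshold) so the example runs without torch.

```python
import inspect


def step_scheduler(scheduler, metric=None):
    """Call scheduler.step(metric) if its step() takes `metrics`, else step()."""
    if "metrics" in inspect.signature(scheduler.step).parameters:
        scheduler.step(metric)
    else:
        scheduler.step()


def current_lr(optimizer):
    """Read the live learning rate (first param group) for display."""
    return optimizer.param_groups[0]["lr"]


class FakeOptimizer:
    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


class FakePlateau:
    """Simplified stand-in for ReduceLROnPlateau (no patience/threshold)."""

    def __init__(self, optimizer, factor=0.1):
        self.optimizer = optimizer
        self.factor = factor
        self.best = float("inf")

    def step(self, metrics, epoch=None):
        if metrics < self.best:
            self.best = metrics
        else:
            for group in self.optimizer.param_groups:
                group["lr"] *= self.factor


opt = FakeOptimizer(1e-3)
sched = FakePlateau(opt)
for val_loss in (0.9, 0.8, 0.85):  # the last epoch does not improve
    step_scheduler(sched, val_loss)
    print(f"val_loss={val_loss} lr={current_lr(opt):.0e}")
```

With real PyTorch objects the same helpers apply unchanged, e.g. `step_scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), val_loss)`, while schedulers such as `StepLR` are stepped without an argument.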
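The tutorial seems to say the device is selected automatically, but in practice it still shows cpu, even though torchkeras is updated to the latest version.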