lansinuote / huggingface_toturials Goto Github PK
View Code? Open in Web Editor NEWbert-base-chinese example
bert-base-chinese example
Int[3]:
labels = torch.LongTensor(labels)
代码地址
这里labels是str类型, 无法转换成 LongTensor
def collate_fn(data):
# print('data: %s' % data)
# 取每个评论的第一个字符作为数据, 第二作为标签
# 猫和老鼠的DVD,我在当当网已买过10余次了。除了做为礼物送给亲朋好有的孩子外,...
sents = [i[0] for i in data] # 猫
labels = [i[1] for i in data] # 和
# print('sents: %s sents长度: %s' % (sents, len(sents)))
# print('labels: %s' % labels)
# 编码
data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
truncation=True,
padding='max_length',
max_length=500,
return_tensors='pt',
return_length=True)
labels_data = token.batch_encode_plus(batch_text_or_text_pairs=labels,
truncation=True,
padding='max_length',
max_length=500,
return_tensors='pt',
return_length=True)
# input_ids:编码之后的数字
# attention_mask:是补零的位置是0,其他位置是1
input_ids = data['input_ids']
attention_mask = data['attention_mask']
token_type_ids = data['token_type_ids']
n_labels = torch.LongTensor(labels_data['input_ids'])
#print(data['length'], data['length'].max())
return input_ids, attention_mask, token_type_ids, n_labels
这里再打印 'labels' 是这样的
# 数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'],
batch_size=16,
collate_fn=collate_fn,
shuffle=True,
drop_last=True)
for i, values in enumerate(loader):
input_ids, attention_mask, token_type_ids, labels = values
# print(i, values)
print(i, input_ids.shape,
attention_mask.shape,
token_type_ids.shape,
labels.shape,
)
if i >= 5:
break
print(len(loader))
输出结果
0 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500])
1 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500])
2 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500])
3 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500])
4 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500])
5 torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500]) torch.Size([16, 500])
600
计算代码, (转到GPU加速)
from transformers import AdamW
# 训练
optimizer = AdamW(model.parameters(), lr=5e-4)
loss = torch.nn.CrossEntropyLoss()
model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)
labels = labels.to(device)
out = model(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
# 梯度下降
l = loss(out, labels)
optimizer.zero_grad()
l.backward()
optimizer.step()
if i % 5 == 0:
out = out.cpu()
labels = labels.cpu()
out = out.argmax(dim=1)
accuracy = (out == labels).sum().item() / len(labels)
print(i, l.item(), accuracy)
if i == 300:
break
RuntimeError: 0D or 1D target tensor expected, multi-target not supported
/home/mylady/.virtualenvs/dl-pytorch/lib/python3.8/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[10], line 24
19 out = model(input_ids=input_ids,
20 attention_mask=attention_mask,
21 token_type_ids=token_type_ids)
23 # 梯度下降
---> 24 l = loss(out, labels)
25 optimizer.zero_grad()
26 l.backward()
File ~/.virtualenvs/dl-pytorch/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.virtualenvs/dl-pytorch/lib/python3.8/site-packages/torch/nn/modules/loss.py:1174, in CrossEntropyLoss.forward(self, input, target)
1173 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1174 return F.cross_entropy(input, target, weight=self.weight,
1175 ignore_index=self.ignore_index, reduction=self.reduction,
1176 label_smoothing=self.label_smoothing)
File ~/.virtualenvs/dl-pytorch/lib/python3.8/site-packages/torch/nn/functional.py:3026, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
3024 if size_average is not None or reduce is not None:
3025 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: 0D or 1D target tensor expected, multi-target not supported
如题🤣
up主优化器里没加过滤器filter把requires_grad = False的参数过滤掉啊,这样冻结不就没有用了吗?
第一个案例collate_fn函数传入data参数,请问这个data参数在哪里定义了
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Some thing interesting about web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Some thing interesting about visualization, use data art
Some thing interesting about game, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.