Git Product home page Git Product logo

bert4torch's Introduction

bert4torch

一款用pytorch来复现bert4keras的简洁训练框架

licence GitHub release PyPI PyPI - Downloads GitHub stars GitHub Issues contributions welcome

Documentation | Torch4keras | Examples

1. 下载安装

安装稳定版

pip install bert4torch

安装最新版

pip install git+https://github.com/Tongjilibo/bert4torch
  • 注意事项:pip包的发布慢于git上的开发版本,git clone注意引用路径,注意权重是否需要转换
  • 测试用例git clone https://github.com/Tongjilibo/bert4torch,修改example中的预训练模型文件路径和数据路径即可启动脚本
  • 自行训练:针对自己的数据,修改相应的数据处理代码块
  • 开发环境:使用torch==1.10版本进行开发,如其他版本遇到不适配,欢迎反馈

2. 功能

  • LLM模型: 加载chatglm-6b和llama-7b进行推理和finetune

  • 核心功能:加载bert、roberta、albert、xlnet、nezha、bart、RoFormer、RoFormer_V2、ELECTRA、GPT、GPT2、T5、GAU-alpha、ERNIE等预训练权重继续进行finetune、并支持在bert基础上灵活定义自己模型

  • 丰富示例:包含pretrainsentence_classficationsentence_embeddingsequence_labelingrelation_extractionseq2seqserving等多种解决方案

  • 实验验证:已在公开数据集实验验证,使用如下examples数据集

  • 易用trick:集成了常见的trick,即插即用

  • 其他特性加载transformers库模型一起使用;调用方式简洁高效;有训练进度条动态展示;配合torchinfo打印参数量;默认Logger和Tensorboard简便记录训练过程;自定义fit过程,满足高阶需求

  • 训练过程

    2022-10-28 23:16:10 - Start Training
    2022-10-28 23:16:10 - Epoch: 1/5
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.1351 - acc: 0.9601
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 798.09it/s] 
    test_acc: 0.98045. best_test_acc: 0.98045
    
    2022-10-28 23:16:27 - Epoch: 2/5
    5000/5000 [==============================] - 13s 3ms/step - loss: 0.0465 - acc: 0.9862
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 635.78it/s] 
    test_acc: 0.98280. best_test_acc: 0.98280
    
    2022-10-28 23:16:44 - Epoch: 3/5
    5000/5000 [==============================] - 15s 3ms/step - loss: 0.0284 - acc: 0.9915
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 673.60it/s] 
    test_acc: 0.98365. best_test_acc: 0.98365
    
    2022-10-28 23:17:03 - Epoch: 4/5
    5000/5000 [==============================] - 15s 3ms/step - loss: 0.0179 - acc: 0.9948
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 692.34it/s] 
    test_acc: 0.98265. best_test_acc: 0.98365
    
    2022-10-28 23:17:21 - Epoch: 5/5
    5000/5000 [==============================] - 14s 3ms/step - loss: 0.0129 - acc: 0.9958
    Evaluate: 100%|██████████████████████████████████████████████████| 2500/2500 [00:03<00:00, 701.77it/s] 
    test_acc: 0.98585. best_test_acc: 0.98585
    
    2022-10-28 23:17:37 - Finish Training
    

3. 快速上手

4. 版本说明

4.1 更新历史

  • v0.2.8(待发布):增加chatglm-6b/llama-7b预训练模型, 修改rope为不使用max_position,修复model.half()类型不一致问题,生成式解码新增SeqGeneration和Seq2SeqGeneration, 支持加载多个权重文件, gpt系列默认不加softmax,更新fnlp的bart2.0。增加苏神Tiger的pytorch实现, 集成苏神、uer的roberta-small/Tiny模型以及ChatYuan v2模型, 增加了对attention_key_size的入参支持(skykiseki用户)
  • v0.2.7.post2:20230310 增加lion优化器, 修复albert_unshared加载权重, 修复lm系列(gpt, seq2seq)存在的forward参数不对的问题,修复GlobalPointer使用rope的bug
  • v0.2.7:20230213 修复random_sample()的bug,适配v0.0.6的torch4keras:增加resume_from_checkpoint和save_to_checkpoint;增加add_trainer方法,重构了Trainer(BaseModel)的实现,增加了AccelerateCallback
  • v0.2.6:20221231 build_transformer_model需显式指定add_trainer才从BaseModel继承, 增加guwenbert, macbert,text2vec-bert-chinese, wobert预训练模型,允许position_ids从padding开始, transformer.configs支持点操作,可以使用torch4keras的Trainer(net)来初始化, 修复tokenizer的切分subtoken的bug, 允许embedding_size!=hidden_size
  • v0.2.5:20221127 对抗训练从compile转为使用Callback来实现,修复1.7.1版本兼容bug, uie模型内置
  • v0.2.4:20221120 删除SpTokenizer基类中的rematch, 增加deberta_v2模型
  • v0.2.3:20221023 虚拟对抗VAT在多个ouput时支持指定,把Trainer抽象到torch4keras中,修复DP和DDP出现resume_epoch不存在的bug, tokenizer的never_split去除None, transformer_xl的bug, 增加gradient_checkpoint
  • v0.2.2:20220922 修复t5的norm_mode问题,允许hidden_size不整除num_attention_heads,支持多个schedule(如同时ema+warmup)
  • v0.2.1:20220905 兼容torch<=1.7.1的torch.div无rounding_mode,增加自定义metrics,支持断点续训,增加默认Logger和Tensorboard日志
  • v0.2.0:20220823 兼容torch<1.9.0的缺失take_along_dim,修复bart中位置向量514的问题,修复Sptokenizer对符号不转换,打印Epoch开始的时间戳,增加parallel_apply
  • v0.1.9:20220808 增加mixup/manifold_mixup/temporal_ensembling策略,修复pgd策略param.grad为空的问题,修改tokenizer支持批量
  • v0.1.8:20220717 修复原来CRF训练中loss陡增的问题,修复xlnet的token_type_ids输入显存占用大的问题
  • v0.1.7:20220710 增加EarlyStop,CRF中自带转bool类型
  • v0.1.6:20220605 增加transformer_xl、xlnet、t5_pegasus模型,prompt、预训练等示例,支持增加embedding输入,EMA策略,修复tokenizer和sinusoid的bug
  • v0.1.5:20220504 增加GAU-alpha,混合梯度,梯度裁剪,单机多卡(DP、DDP)
  • v0.1.4:20220421 增加了VAT,修复了linux下apply_embedding返回项有问题的情况
  • v0.1.3:20220409 初始版本

4.2 版本对应关系

bert4torch版本 torch4keras版本
0.2.7.post2 0.0.6
0.2.7 0.0.6
0.2.6 0.0.5
0.2.5 0.0.4
0.2.4 0.0.3.post2
0.2.3 0.0.2
<0.2.3 ——

5. 更新:

  • 20230426:增加vicuna的集成
  • 20230408:增加苏神Tiger的pytorch实现, 集成苏神、uer的roberta-small/Tiny模型以及ChatYuan v2模型, 增加了对attention_key_size的入参支持,单向decoder模型和encoder decoder模型解码增加cache, 更新fnlp的bart2.0, 增加chatglm-6b预训练模型推理, 集成BELLE_llama模型, 增加量化模块并适配llama,增加skip_init参数加快加载, 增加stream输出/网页demo, 增加ptuning_v2,增加moss模型的int4/int8推理
  • 20230326:增加llama-7b预训练模型, 修改rope为不使用max_position, 增加prompt_clue和nezha_gpt_dialog的finetune示例(skykiseki用户),修复model.half()类型不一致问题,生成式解码新增SeqGeneration和Seq2SeqGeneration, 支持加载多个权重文件, gpt系列默认不加softmax
  • 20230310:增加lion优化器, 修改dp和ddp示例更易用,增加PromptCLUE模型, 修复albert_unshared加载权重, 增加uer-gpt2-chinese预训练模型,修复lm系列(gpt, seq2seq)存在的forward参数不对的问题,修复GlobalPointer使用rope的bug
  • 20230212:兼容accelerate包, 增加ChatYuan v1模型,修复random_sample()的bug
  • 20221230:增加macbert,text2vec-bert-chinese, wobert模型,增加LEAR的ner示例, 增加PGRC、SPN4RE的关系提取示例,transformer.configs支持点操作,可以使用torch4keras的Trainer(net)来初始化, 修复tokenizer的切分subtoken的bug, 允许embedding_size!=hidden_size
  • 20221127:增加deberta_v2模型, 对抗训练从compile转为使用Callback来实现,修复1.7.1版本兼容bug, uie模型内置, 增加triton示例, build_transformer_model需显式指定add_trainer才从BaseModel继承, 增加guwenbert预训练模型,允许position_ids从padding开始
  • 20221102:增加CNN_Nested_NER示例, 删除SpTokenizer基类中的rematch
  • 20221022:修复DP和DDP出现resume_epoch不存在的bug, tokenizer的never_split去除None, transformer_xl的bug, 增加gradient_checkpoint
  • 20221011:虚拟对抗VAT在多个ouput时支持指定,增加elasticsearch示例, 把Trainer抽象到torch4keras中供更多项目使用,把梯度累积移到compile中
  • 20220920:增加TensorRT示例,支持多个schedule(如同时ema+warmup),sanic+onnx部署
  • 20220910:增加默认Logger和Tensorboard日志,ONNX推理,增加ERNIE模型,修复t5的norm_mode问题,允许hidden_size不整除num_attention_heads
  • 20220828:增加nl2sql示例,增加自定义metrics,支持断点续训
  • 20220821:增加W2NER和DiffCSE示例,打印Epoch开始的时间戳,增加parallel_apply,兼容torch<=1.7.1的torch.div无rounding_mode
  • 20220814:增加有监督句向量、关系抽取、文本生成实验指标,兼容torch<1.9.0的缺失take_along_dim,修复bart中位置向量514的问题,修复Sptokenizer对符号不转换
  • 20220727:增加mixup/manifold_mixup/temporal_ensembling策略,修复pgd策略param.grad为空的问题,修改tokenizer支持批量,增加uie示例
  • 20220716:修复原来CRF训练中loss陡增的问题,修复xlnet的token_type_ids输入显存占用大的问题
  • 20220710:增加金融中文FAQ示例,天池新闻分类top1案例,增加EarlyStop,CRF中自带转bool类型
  • 20220629:增加ner的实验,测试crf不同初始化的效果,bert-whitening中文实验
  • 20220613:增加seq2seq+前缀树,增加SimCSE/ESimCSE/PromptBert等无监督语义相似度的中文实验
  • 20220605:增加PromptBert、PET、P-tuning示例,修改tokenizer对special_tokens分词错误的问题,增加t5_pegasus
  • 20220529:transformer_xl、xlnet模型,修改sinusoid位置向量被init_weight的bug,EMA,sohu情感分类示例
  • 20220517:增加预训练代码,支持增加embedding输入(如词性,word粒度embedding)
  • 20220501:增加了混合梯度,梯度裁剪,单机多卡训练(DP、DDP)
  • 20220425:增加了VAT、GAU-alpha等示例,增加了梯度累积,自定义fit()示例
  • 20220415:增加了ner_mrc、ner_span、roformer_v2、roformer-sim等示例
  • 20220405:增加了GPLinker、TPlinker、SimBERT等示例
  • 20220329:增加了CoSENT、R-Drop、UDA等示例
  • 20220322:添加GPT、GPT2、T5模型
  • 20220312:初版提交

6. 预训练权重

模型分类 权重来源 权重链接 备注(若有)
bert 谷歌原版bert(即bert-base-chinese) tftorch tf转pytorch命令转换脚本
bert 哈工大chinese-bert-wwm-ext tf/torchtorch
macbert 哈工大chinese-macbert-base/large tf/torchtorch
roberta 哈工大chinese-roberta-wwm-ext tf/torchtorch
roberta-small/tiny 追一科技 & UER tftorch 转换脚本
deberta_v2 IDEA Erlangshen-DeBERTa-v2 torch 转换脚本
guwenbert 古文bert torch 转换脚本
xlnet 哈工大xlnet tf/torch config
electra 哈工大electra tftorch
macbert 哈工大macbert tftorch
albert brightmart tftorchtorch
ernie 百度文心 paddletorch
roformer 追一科技 tftorch
roformer_v2 追一科技 tftorch
simbert 追一科技 tftorch_base 转换脚本
simbert_v2/roformer-sim 追一科技 tftorch
gau-alpha 追一科技 tf 转换脚本
wobert 追一科技 tftorch_basetorch_plus_base
nezha 华为 tftorch
gpt thu-coai/CDial-GPT torch 转换脚本
gpt2 清华26亿 cmp_lm torch 转换脚本
gpt2 中文GPT2_ML模型 tftorch 转换脚本
gpt2 UER torch 转换脚本
t5 UER torch config
mt5 谷歌 torch config
t5_pegasus 追一科技 tf 转换脚本
bart 复旦 torch, v1.0, v2.0 转换脚本
text2vec text2vec-base-chinese torch
chatyuan v1&v2 clue-ai torch config
PromptCLUE clue-ai torch config
llama facebook torch 转换脚本
vicuna FastChat torch 转换脚本
Belle LianjiaTech torch 合成说明转换脚本
chatglm THUDM torch 转换脚本

7. 鸣谢

  • 感谢苏神实现的bert4keras,本实现有不少地方参考了bert4keras的源码,在此衷心感谢大佬的无私奉献;
  • 其次感谢项目bert4pytorch,也是在该项目的指引下给了我用pytorch来复现bert4keras的想法和思路。

8. 引用

@misc{bert4torch,
  title={bert4torch},
  author={Bo Li},
  year={2022},
  howpublished={\url{https://github.com/Tongjilibo/bert4torch}},
}

9. 其他

  • Wechat Discussions
pic
微信号
pic
微信群
  • Star History Chart
pic
Star History Chart

bert4torch's People

Contributors

skykiseki avatar tongjilibo avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.