
cpm-lm-tf2's Introduction

TensorFlow 2.x CPM-Generate

This repo converts the model to a TensorFlow version. Original repo: https://github.com/TsinghuaAI/CPM-Generate

Original project homepage: https://cpm.baai.ac.cn/

Original project introduction article: https://mp.weixin.qq.com/s/oI2Ak-M57MSuycLVpVEiHw

If you just want a quick look at the results, open `prediction_large.ipynb` for a preview.

Thanks to BAAI (Beijing Academy of Artificial Intelligence) for their work!

^_^ If you like this work, please star BAAI's original repo; a star for this repo would be appreciated too.

Usage

HINT: Please test with TensorFlow 2.3.0 or later; other versions may have compatibility issues. This project is primarily for learning, and extensive compatibility work is not planned.

  1. Clone this repo

  2. Download a model:


Baidu Netdisk, large model cpm-large-tf2
Link: https://pan.baidu.com/s/1gup1qhojFr4jC_a70tlmgw  Password: 5lba

Baidu Netdisk, small model cpm-distill-tf2
Link: https://pan.baidu.com/s/10vhjVRX2tWbX2892ulLemg  Password: dsvv

or Google Drive:

Large model
https://drive.google.com/drive/folders/1XGy2B6QSf1k0SOtVF13gcekLvdtpPkmq?usp=sharing

Small model
https://drive.google.com/drive/folders/13wTPVMEslAx8Xl0Sfw59BbSUR-BowHgW?usp=sharing

Download either cpm-large-tf2 or the smaller cpm-distill-tf2; one is enough. The large model certainly gives better results; to see the difference, compare the outputs in prediction_large.ipynb and prediction_distill.ipynb.

Download it into the cloned repo directory, so the structure looks roughly like this:

cpm-lm-tf2/
...cpm-large-tf2/  (TF2 large model downloaded from the links above)
......assets
......saved_model.pb
......variables
...cpm-distill-tf2/  (TF2 small model; download only one of the two)
......assets
......saved_model.pb
......variables
...CPM-Generate/
......bpe_3w_new/  (vocabulary directory)
...prediction_large.ipynb  (demo notebook for the large model; run this if you downloaded the large model)
...prediction_distill.ipynb  (demo notebook for the small model; run this if you downloaded the small model)
...gpt2_tokenizer.py  (tokenizer; it uses jieba and is not interchangeable with the huggingface tokenizers)

The files and directories above are roughly all you need to run the demos; the rest is mostly model-conversion code.
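Before running the notebooks, it can help to confirm the downloaded SavedModel directory is complete. A minimal sketch (the directory name is whichever model you downloaded):

```python
import os

def missing_saved_model_parts(model_dir):
    """Return the expected SavedModel entries that are absent from model_dir."""
    expected = ["saved_model.pb", "variables", "assets"]
    return [name for name in expected
            if not os.path.exists(os.path.join(model_dir, name))]

# After a successful download, missing_saved_model_parts("cpm-large-tf2")
# should return an empty list.
```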
  3. Install dependencies
# Dependencies:
pip install sentencepiece jieba regex tensorflow tensorflow-hub
  4. Run the code in prediction_large.ipynb, or prediction_distill.ipynb for the small model
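The `sample` call in the notebooks takes `top_p` and `temperature` parameters. As a rough, pure-Python illustration of what nucleus (top-p) filtering does (a sketch, not the repo's actual implementation):

```python
import math

def top_p_filter(logits, top_p=0.9, temperature=1.0):
    """Softmax the logits, keep the smallest set of tokens whose
    cumulative probability reaches top_p, then renormalize."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # stabilize the softmax
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # consider tokens from most to least probable
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    z = sum(probs[i] for i in kept)
    return {i: probs[i] / z for i in kept}

# With a low top_p only the most probable token survives:
print(top_p_filter([2.0, 1.0, 0.1], top_p=0.5))
```

Lower `top_p` makes generation more conservative; higher `temperature` flattens the distribution before filtering.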

Differences between the TensorFlow version and the original

In principle there should be no major difference, since the original parameters are loaded. Still, the GPU -> CPU and PyTorch -> TensorFlow conversions mean the results may deviate somewhat from the original model; the author estimates the deviation is small, roughly 1% at most.
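One way to quantify such a discrepancy yourself is to feed the same input to both models and compare their output logits. A minimal sketch, with hypothetical numbers standing in for the two frameworks' outputs:

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two flat logit lists."""
    assert len(a) == len(b)
    return max(abs(x - y) for x, y in zip(a, b))

pytorch_logits = [0.12, -1.50, 3.01]  # hypothetical original model outputs
tf2_logits = [0.12, -1.49, 3.02]      # hypothetical converted model outputs
print(max_abs_diff(pytorch_logits, tf2_logits))
```

If the maximum difference stays small (on the order of float32 rounding plus conversion noise), the conversion is likely faithful; large per-layer differences point to a weight-mapping bug.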

Citation

From the original repo:

@article{cpm-v1,
  title={CPM: A Large-scale Generative Chinese Pre-trained Language Model},
  author={Zhang, Zhengyan and Han, Xu and Zhou, Hao and Ke, Pei and Gu, Yuxian and Ye, Deming and Qin, Yujia and Su, Yusheng and Ji, Haozhe and Guan, Jian and Qi, Fanchao and Wang, Xiaozhi and Zheng, Yanan and Cao, Jiannan and Zeng, Guoyang and Cao, Huanqi and Chen, Shengqi and Li, Daixuan and Sun, Zhenbo and Liu, Zhiyuan and Huang, Minlie and Han, Wentao and Tang, Jie and Li, Juanzi and Sun, Maosong},
  year={2020}
}

cpm-lm-tf2's People

Contributors

qhduan, xingyaoww


cpm-lm-tf2's Issues

How can I loop over keyboard input for testing?

Hello, and thank you for this version, which lets us test on a single GPU or CPU-only machine and look for inspiration.
One small pain point: every time the content changes, the script must be re-run (editing the contents of sample), with a long model-loading wait each time.
I looked at demo.py from the GPT2-ML project, which uses input() to take new test content from the keyboard without re-running the script, but I could not get that to work here.
I would appreciate your guidance. Thanks!

Converting the PyTorch model to TF degrades the output

Hello. Following your script, I converted the distill model from PyTorch to TF, but the output quality got much worse. Do you know where it might have gone wrong?

The script was unchanged, except that it still loads the model on GPU.

Loading model cost 0.702 seconds.
Prefix dict has been built successfully.
(1, 1, 30000) (12, 1, 2, 12, 1, 64)
(1, 1, 30000) (12, 1, 2, 12, 1, 64)
tf.Tensor(
[[  837   259   497   788 22707 22707 22707 22707 22707 22707 22707 22707
  22707 22707 22707 22707 22707 22707 22707 22707]], shape=(1, 20), dtype=int64)
今天天气 不错 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵 猥亵
tf.Tensor(
[[  837   259   497   788 24672  6655  7254  6123 22707  2779  8494 28689
  20220 28689  2779  2779 28689 22707  2779  5469]], shape=(1, 20), dtype=int64)
今天天气 不错 裁定穷 驾清楚 猥亵脑 10000畫 电饭畫脑脑畫 猥亵脑 一方

How to train the model

Thank you for your work, which lets us run the model even on tensorflow-cpu.
But if I want to do some targeted training on the current pretrained model, how should I go about it?
After loading the model with hub.load(), how can I train it?
I would appreciate an answer. Thanks!

Text generation is too slow

Hello, text generation in this project is rather slow, averaging 2.6 s per sample on a GPU machine. Is there any way to speed it up?

Error when running

Hello. When I use the model (the prediction_v2 file) following the given example and run:

ret = sample(tokenizer, gpt, '''匆匆瞥了金威南一眼,深邃的眼眸有一层让人沉溺的雾气,她敛眸,跟他错开。''', 3, 3000, top_p=0.9, temperature=0.9)
for x in ret:
    print(x)
    print('-' * 20)

the system reports this error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [1,1,0,0] vs. [3,32,0,1024]
[[{{node StatefulPartitionedCall/while/body/_1043/while/StatefulPartitionedCall/gpt/layer00/attention/mul_2}}]] [Op:__inference_signature_wrapper_100226]

Function call stack:
signature_wrapper

May I ask what causes this?

How to compute perplexity

How can I compute perplexity in this project? My confusion is that the saved tensorflow_hub model seems tailored specifically to text generation (judging from the related code in tf2gpt/loading.ipynb). I want to use your model parameters but run a single forward pass: what I need is the layer just before ['output_0'] (before the argmax), then compute the cross-entropy loss against the labels to obtain perplexity. Do you have any suggestions? Thanks!

Your code does text generation like this:

ret = gpt.signatures['serving_default'](
    inp=inputs,
    length=length,
    top_p=tf.constant(top_p, tf.float32),
    temperature=tf.constant(temperature, tf.float32)
)['output_0']
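For the question above: given access to per-step logits before the argmax, perplexity is the exponential of the mean token-level cross-entropy. A self-contained sketch with hypothetical logits (not the hub model's actual interface, which only exposes the generation signature):

```python
import math

def perplexity(step_logits, labels):
    """step_logits: one logit list per position; labels: gold token ids.
    Returns exp(mean negative log-likelihood) over the sequence."""
    total_nll = 0.0
    for logits, y in zip(step_logits, labels):
        m = max(logits)  # stabilize log-sum-exp
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total_nll += log_z - logits[y]  # -log softmax(logits)[y]
    return math.exp(total_nll / len(labels))

# Uniform logits over a 4-token vocabulary give perplexity 4:
print(perplexity([[0.0] * 4, [0.0] * 4], [1, 3]))
```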
