Git Product home page Git Product logo

textclassifier_transformer's Introduction

TextClassifier_Transformer

个人基于谷歌开源的BERT编写的文本分类器(基于微调方式),可自由加载NLP领域知名的预训练语言模型BERT、 Roberta、ALBert及其wwm版本,同时适配ERNIE1.0.
该项目支持两种预测方式:
(1)线下实时预测
(2)服务端实时预测

新增改动

2020-03-25:
(1)项目名由'TextClassifier_BERT'更改为'TextClassifier_Transformer';
(2)新增ELECTRA、AlBert两个预训练模型。
注意:在使用AlBert时,请将该项目下的modeling.py文件更新为ALBert项目中下的modeling.py,而后在运行
2020-03-04:
模型部署增加tf-serving机制,具体实施方式见This Blog

运行环境

  • Python3.6+
  • Tensorflow1.10+/Tensorflow-gpu1.10+

提供知名的预训练语言模型下载地址(其中百度开源的Ernie模型已转换成tf格式):
Bert-base:链接:https://pan.baidu.com/s/18h9zgvnlU5ztwaQNnzBXTg 提取码:9r1z
Roberta:链接:https://pan.baidu.com/s/1FBpok7U9ekYJRu1a8NSM-Q 提取码:i50r
Bert-wwm:链接:链接:https://pan.baidu.com/s/1lhoJCT_LkboC1_1YXk1ItQ 提取码:ejt7
ERNIE1.0:链接:链接:https://pan.baidu.com/s/1S6MI8rQyQ4U7dLszyb73Yw 提取码:gc6f
ELECTRA-Tiny:链接:https://pan.baidu.com/s/11QaL7A4YSCYq4YlGyU1_vA 提取码:27jb
AlBert-base:链接:https://pan.baidu.com/s/1U7Zx73ngci2Oqp3SLaVOaw 提取码:uijw

项目说明

主要分为两种运行模式:
模式1:线下实时预测
step1:数据准备
step2:模型训练
step3:模型导出
step4:线下实时预测
模式2:服务端实时预测 step1:数据准备
step2:模型训练
step3:模型转换
step4:服务部署
step5:应用端

注意事项

1.如果你只是想体验从模型训练到本地线下预测这一套流程,只需要按照模式1依次执行即可
2.若你想想体验从模型训练到模型部署整个流程,则需要按照模式2依次执行

下面将针对以上两个模式的运行方式进行详细说明。

模式1:线下实时预测

Step1:数据准备

为了快速实验项目效果,这里使用了样本规模较小的手机评论数据,数据比较简单,有三个分类:-1(差评)、0(中评)、1(好评),数据样例如下所示:
数据描述 ps:本项目中已将其拆分成了train.tsv、dev.txv、test.tsv三个文件

Step2:模型训练

训练命令:

bash train.sh

train.sh参数说明:

export BERT_BASE_DIR=./chinese_roberta_zh_l12 #指定预训练的语言模型所在路径
export DATA_DIR=./dat #指定数据集所在路径
export TRAINED_CLASSIFIER=./output #训练的模型输出路径
export MODEL_NAME=mobile_0_roberta_base #训练的模型命名

详细说明:训练模型直接使用bert微调的方式进行训练,对应的程序文件为run_classifier_serving.py。关于微调bert进行训练的代码网上介绍的 很多,这里就不一一介绍。主要是创建针对该任务的Processor即:SentimentProcessor,在这个processor的_create_examples()和get_labels()函数自定义,如下所示:

class SetimentProcessor(DataProcessor):
  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""

    """
    if not os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
        with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as fd:
            pickle.dump(self.labels, fd)
    """
    return ["-1", "0", "1"]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0: 
        continue
      guid = "%s-%s" % (set_type, i)

      #debug (by xmxoxo)
      #print("read line: No.%d" % i)

      text_a = tokenization.convert_to_unicode(line[1])
      if set_type == "test":
        label = "0"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, label=label))
    return examples


注意,此处作出的一个特别变动之处是在conver_single_example()函数中增加了一段保存label的代码,在训练过程中在保存的模型路径下生成label2id.pkl文件,代码如下所示:

#--- save label2id.pkl ---
#在这里输出label2id.pkl , add by stephen 2019-10-12
output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
if not os.path.exists(output_label2id_file):
   with open(output_label2id_file,'wb') as w:
      pickle.dump(label_map,w)
#--- Add end ---

Step3:模型导出

运行如下命令:

bash export.sh

export.sh参数说明:

#以下四个参数应与train.sh中设置的值保持一致
export BERT_BASE_DIR=./chinese_roberta_zh_l12
export DATA_DIR=./dat
export TRAINED_CLASSIFIER=./output
export MODEL_NAME=mobile_0_roberta_base

会在指定的exported目录下生成以一个时间戳命名的模型目录。
详细说明:run_classifier.py 主要设计为单次运行的目的,如果把 do_predict 参数设置成 True,倒也确实可以预测,但输入样本是基于文件的,并且不支持将模型持久化在内存里进行 serving,因此需要自己改一些代码,达到两个目的:
(1)允许将模型加载到内存里,即:允许一次加载,多次调用。
(2)允许读取非文件中的样本进行预测。譬如从标准输入流读取样本输入。

def serving_input_fn():
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })
    return input_fn

继而在run_classifier_serving中定义do_export选项:

if do_export:
   estimator._export_to_tpu = False
   estimator.export_savedmodel(Flags.export_dir, serving_input_fn)

Step4:线下实时预测

运行test_serving.py文件,即可进行线下实时预测。
运行效果如下所示:
运行效果图
详细说明:导出模型后,就不需要第 859 行那个 estimator 对象了,可以自行从刚刚的导出模型目录加载模型,代码如下:

predict_fn = tf.contrib.predictor.from_saved_model('/exported/1571054350')

基于上面的 predict_fn 变量,就可以直接进行预测了。下面是一个从标准输入流读取问题样本,并预测分类的样例代码:

while True:
    question = input("> ")
    predict_example = InputExample("id", question, None, '某固定伪标记')
    feature = convert_single_example(100, predict_example, label_list,
                                        FLAGS.max_seq_length, tokenizer)
 
    prediction = predict_fn({
        "input_ids":[feature.input_ids],
        "input_mask":[feature.input_mask],
        "segment_ids":[feature.segment_ids],
        "label_ids":[feature.label_id],
    })
    probabilities = prediction["probabilities"]
    label = label_list[probabilities.argmax()]
    print(label)

模式2:服务端实时预测

首先针对该模式的基本架构进行说明:
服务端部署架构
架构说明:
BERT模型服务端:加载模型,进行实时预测的服务; 使用的是 BERT-BiLSTM-CRF-NER提供的bert-base;
API服务端:调用实时预测服务,为应用提供API接口的服务,用flask编写;
应用端:最终的应用端; 我这里为了简便,并没有编写网页,直接调用了api接口。

Step1:数据准备

同模式1中的Step1介绍。

Step2:模型训练

同模式1中的Step2介绍。

Step3:模型转换

运行如下命令:

bash model_convert.sh

会在$TRAINED_CLASSIFIER/$EXP_NAME生成pb格式的模型文件
model_convert.sh参数说明:

export BERT_BASE_DIR=./chinese_roberta_zh_l12 #训练模型时使用的预训练语言模型所在路径
export TRAINED_CLASSIFIER=./output #训练好的模型输出的路径
export EXP_NAME=mobile_0_roberta_base #训练后保存的模型命名

python freeze_graph.py \
    -bert_model_dir $BERT_BASE_DIR \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -max_seq_len 128 #注意,这里的max_seq_len应与训练的脚本train.sh设置的max_seq_length参数值保持一致

Step4:模型部署

运行如下命令:

bash bert_classify_server.sh 

提示:在运行这个命令前需要保证安装了bert-base这个库,使用如下命令进行安装:

pip install bert-base

注意
port 和 port_out 这两个参数是API调用的端口号,默认是5555和5556,如果你准备部署多个模型服务实例,那一定要指定自己的端口号,避免冲突。 我这里是改为: 5575 和 5576
如果报错没运行起来,可能是有些模块没装上,都是 bert_base/server/http.py里引用的,装上就好了:

sudo pip install flask 
sudo pip install flask_compress
sudo pip install flask_cors
sudo pip install flask_json

我这里的配置是2个GTX 1080 Ti,这个时候双卡的优势终于发挥出来了,GPU 1用于预测,GPU 0还可以继续训练模型。
部署成功示例图如下:
部署成功示例图

Step5:应用端

运行如下命令:

python api/api_service_flask.py

即可通过指定api接口(本项目中是http://192.168.9.23:8910/predict_online?text=我好开心)访问部署的服务器。
通过浏览器进行请求:
浏览器请求

textclassifier_transformer's People

Contributors

vincent131499 avatar

Watchers

James Cloos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.