albert_lstm_crf_ner's Introduction

Named entity recognition with ALBERT + BiLSTM + CRF (PyTorch)

outline

Model structure of lstm_crf

[Figure: lstm_crf]

Model structure of albert_lstm

[Figure: albert_embedding_lstm]

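1. Each sentence is split into character tokens, each token is mapped to an integer ID, and an attention mask is added; the result is fed to ALBERT, which produces a matrix representation of the sentence. For example, with batch=10 and a maximum sentence length of 126, adding the [CLS] and [SEP] markers at the start and end gives max_length=128, so the output of the albert_base_zh model has shape (batch, max_length, hidden_states) = (10, 128, 768).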

2. The representations produced by ALBERT serve as the embedding layer of the LSTM.

3. ALBERT is not fine-tuned.
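The preprocessing in point 1 can be sketched in plain Python as below. This is a minimal sketch, not the repository's code: the names `encode`/`encode_batch` and the toy vocabulary are hypothetical, the special-token ids follow the common BERT Chinese vocab layout but should be treated as assumptions, and truncating to `max_length - 2` characters is an assumed policy.

```python
# Toy vocabulary for illustration only; the real ids come from ALBERT's vocab file.
# The special-token ids mirror the common BERT Chinese vocab but are assumptions here.
VOCAB = {"[PAD]": 0, "[UNK]": 100, "[CLS]": 101, "[SEP]": 102}


def encode(sentence, vocab, max_length=128):
    """Split a sentence into characters and return (input_ids, attention_mask)."""
    # Keep at most max_length - 2 characters so [CLS] and [SEP] still fit.
    tokens = ["[CLS]"] + list(sentence)[: max_length - 2] + ["[SEP]"]
    # Map each token to its id, falling back to [UNK] for unknown characters.
    input_ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]
    # 1 marks real tokens, 0 marks padding that ALBERT should ignore.
    mask = [1] * len(input_ids)
    pad = max_length - len(input_ids)
    return input_ids + [vocab["[PAD]"]] * pad, mask + [0] * pad


def encode_batch(sentences, vocab, max_length=128):
    """Encode several sentences into equal-length id and mask rows."""
    pairs = [encode(s, vocab, max_length) for s in sentences]
    return [ids for ids, _ in pairs], [mask for _, mask in pairs]


vocab = {**VOCAB, "国": 5, "家": 6}  # hypothetical character ids
ids, mask = encode("国家", vocab)
print(ids[:5], mask[:5], len(ids))  # [101, 5, 6, 102, 0] [1, 1, 1, 1, 0] 128
```

Encoding 10 sentences this way yields a (10, 128) id matrix plus a matching mask, which ALBERT then maps to the (10, 128, 768) tensor described above.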

train

step 1: run albert/tfmodel_2_pymodel.py

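1. Convert the TensorFlow pretrained model into a model that PyTorch can load.

2. This project uses albert_base_zh (the small trial version): about 12M parameters, 12 layers, roughly 40 MB in size.

3. Place the converted PyTorch model in the albert/pretrain/pytorch directory.

4. The model configuration is in the albert/configs/ directory.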
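ALBERT keeps the parameter count small despite having 12 layers, mainly through factorized embeddings and cross-layer parameter sharing. The back-of-the-envelope count below illustrates this. The hyperparameters (vocab 21128, embedding 128, hidden 768, intermediate 3072, 512 positions) are assumptions meant to resemble an albert_base_zh config, so check albert/configs/ for the real values; `albert_param_count` is a hypothetical helper that ignores pretraining heads.

```python
def albert_param_count(vocab=21128, emb=128, hidden=768, inter=3072,
                       max_pos=512, type_vocab=2, layers=12, shared=True):
    """Rough parameter count (weights + biases) for an ALBERT-style encoder."""
    # Factorized embeddings: tables live in the small `emb` space, plus a LayerNorm.
    embeddings = (vocab + max_pos + type_vocab) * emb + 2 * emb
    # Linear projection from the embedding size up to the hidden size.
    projection = emb * hidden + hidden
    # One transformer layer: Q, K, V and output projections, the FFN, two LayerNorms.
    attention = 4 * (hidden * hidden + hidden)
    ffn = (hidden * inter + inter) + (inter * hidden + hidden)
    layer = attention + ffn + 2 * (2 * hidden)
    # With cross-layer sharing every layer reuses one set of weights.
    encoder = layer if shared else layers * layer
    pooler = hidden * hidden + hidden
    return embeddings + projection + encoder + pooler


shared = albert_param_count()
unshared = albert_param_count(shared=False)
print(shared, unshared)  # 10547968 88514560
print(round(shared * 4 / 2**20, 1))  # float32 size in MiB: 40.2
```

Under these assumptions the shared model comes out at about 10.5M parameters (about 40 MiB as float32), the same order as the figures quoted above, versus roughly 88.5M if each layer had its own weights. The exact count depends on the configuration and on which heads are included.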

step 2: set selected parameters in models/config.yml

embedding_size: 768
hidden_size: 128
model_path: models/
batch_size: 64
max_length: 128
dropout: 0.5
tags:
  - ORG
  - PER
  - LOC
  - T
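The `tags` list holds entity types, not the per-token labels the CRF predicts. A common way to expand it is the BIO scheme. The sketch below assumes BIO and an illustrative label order, since the repository's actual label set is not shown here; `build_tag_map` is a hypothetical name.

```python
def build_tag_map(entity_types):
    """Expand entity types into BIO labels and assign each an integer id.

    Hypothetical helper: the BIO scheme and this label order are assumptions.
    """
    labels = ["O"]  # outside any entity
    for ent in entity_types:
        labels += [f"B-{ent}", f"I-{ent}"]  # beginning / inside of an entity
    return {label: idx for idx, label in enumerate(labels)}


tag_map = build_tag_map(["ORG", "PER", "LOC", "T"])
print(len(tag_map), tag_map["B-PER"])  # 9 3
```

Under this scheme the four entity types give nine labels, which would be the number of output scores the BiLSTM produces per token before the CRF.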

step 3: train

python main.py train
The training data is the annotated People's Daily corpus.
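Training an LSTM-CRF usually means minimizing the CRF negative log-likelihood of the gold tag path, computed with the forward algorithm. The pure-Python sketch below shows that computation for one sentence. It assumes the BiLSTM has already produced `emissions[t][tag]` scores and that the CRF holds a `transitions[prev][next]` matrix plus `start`/`end` score vectors; these names and the exact loss are assumptions, not necessarily what `main.py train` does.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sequence_score(emissions, transitions, tags, start, end):
    """Unnormalized score of one tag path: start + emissions + transitions + end."""
    score = start[tags[0]] + emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score + end[tags[-1]]


def log_partition(emissions, transitions, start, end):
    """Forward algorithm: log of the summed exp(score) over all tag paths."""
    num_tags = len(start)
    # alpha[j]: log-sum of scores of all prefixes ending in tag j.
    alpha = [start[j] + emissions[0][j] for j in range(num_tags)]
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(num_tags)])
            + emissions[t][j]
            for j in range(num_tags)
        ]
    return logsumexp([alpha[j] + end[j] for j in range(num_tags)])


def crf_nll(emissions, transitions, tags, start, end):
    """Negative log-likelihood of the gold tags: log Z - score(gold path)."""
    return log_partition(emissions, transitions, start, end) - sequence_score(
        emissions, transitions, tags, start, end)
```

A PyTorch version would vectorize this over the (batch, seq_len, num_tags) emission tensor and use the padding mask to stop the recursion at each sentence's true length.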

evaluate

> epoch [0] |██                       | 395/4473
  loss 0.07
  epoch [0] |██                       | 396/4473
  loss 0.06
  epoch [0] |██                       | 397/4473
  loss 0.06
  epoch [0] |██                       | 398/4473
  loss 0.06
  epoch [0] |██                       | 399/4473
  loss 0.06
  epoch [0] |██                       | 400/4473
  loss 0.05
  eval
        ORG	recall 1.00	precision 1.00	f1 1.00
        PER	recall 0.97	precision 0.96	f1 0.96
        LOC	recall 1.00	precision 1.00	f1 1.00
        T	recall 0.84	precision 0.80	f1 0.82
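The `eval` lines report recall, precision and F1 per entity type. One standard way to get these numbers is exact-match scoring of entity spans decoded from BIO tags. The sketch below assumes that scheme (including treating a stray `I-` tag as the start of an entity); the function names are hypothetical and the repository's evaluation code may differ.

```python
from collections import Counter


def extract_entities(tags):
    """Return a set of (type, start, end) spans from BIO tags (end is exclusive)."""
    entities, start, ent_type = set(), None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" closes a trailing entity
        # An open entity ends on O, on a new B-, or on an I- of a different type.
        if start is not None and (tag == "O" or tag.startswith("B-") or tag[2:] != ent_type):
            entities.add((ent_type, start, i))
            start, ent_type = None, None
        # B- always opens an entity; a lone I- is leniently treated as a start.
        if tag.startswith("B-") or (tag.startswith("I-") and start is None):
            start, ent_type = i, tag[2:]
    return entities


def per_type_scores(gold_seqs, pred_seqs):
    """Map each entity type to (recall, precision, f1) using exact span matches."""
    tp, n_gold, n_pred = Counter(), Counter(), Counter()
    for gold, pred in zip(gold_seqs, pred_seqs):
        g, p = extract_entities(gold), extract_entities(pred)
        n_gold.update(ent[0] for ent in g)
        n_pred.update(ent[0] for ent in p)
        tp.update(ent[0] for ent in g & p)
    scores = {}
    for ent_type in n_gold.keys() | n_pred.keys():
        recall = tp[ent_type] / n_gold[ent_type] if n_gold[ent_type] else 0.0
        precision = tp[ent_type] / n_pred[ent_type] if n_pred[ent_type] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[ent_type] = (recall, precision, f1)
    return scores


gold = [["B-PER", "I-PER", "O", "B-LOC"]]
pred = [["B-PER", "I-PER", "O", "O"]]
print(per_type_scores(gold, pred)["PER"])  # (1.0, 1.0, 1.0)
```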

predict

python main.py predict
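input text:“刘老根大舞台”被文化部、国家旅游局联合评为首批“国家文化旅游重点项目”
(English gloss: "Liu Laogen Grand Stage" was jointly named by the Ministry of Culture and the National Tourism Administration as one of the first "National Key Cultural Tourism Projects".)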
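At prediction time a CRF tagger typically decodes the highest-scoring tag sequence with the Viterbi algorithm. The sketch below assumes a linear-chain CRF with hypothetical `emissions[t][tag]`, `transitions[prev][next]`, `start` and `end` scores; whether `main.py predict` decodes exactly this way is an assumption.

```python
def viterbi_decode(emissions, transitions, start, end):
    """Return (best_score, best_tag_path) for a linear-chain CRF."""
    num_tags = len(start)
    # score[j]: best score of any prefix ending in tag j at the current position.
    score = [start[j] + emissions[0][j] for j in range(num_tags)]
    backpointers = []
    for t in range(1, len(emissions)):
        best_prev, new_score = [], []
        for j in range(num_tags):
            # Best previous tag i for reaching tag j at position t.
            i_best = max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            best_prev.append(i_best)
            new_score.append(score[i_best] + transitions[i_best][j] + emissions[t][j])
        backpointers.append(best_prev)
        score = new_score
    # Add the end transition, pick the best final tag, then follow backpointers.
    last = max(range(num_tags), key=lambda j: score[j] + end[j])
    best_score = score[last] + end[last]
    path = [last]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    path.reverse()
    return best_score, path
```

The decoded tag ids can then be mapped back to their labels and grouped into entity spans for output.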

note

In src/lstm_crf/model.py:

a. The pretrained ALBERT model serves as the embedding layer:

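> # Load the config of the converted ALBERT checkpoint
  bert_config = BertConfig.from_pretrained(str(config['albert_config_path']), share_type='all')
  self.word_embeddings = BertModel.from_pretrained(config['bert_dir'], config=bert_config)
  self.word_embeddings.to(DEVICE)
  # eval() only disables dropout; gradients are turned off by torch.no_grad() in the forward pass
  self.word_embeddings.eval()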

b. The embedding output has shape (batch_size, seq_len, embedding_dim):

> with torch.no_grad():
        embeddings = self.word_embeddings(input_ids=sentence, attention_mask=mask)
        # The ALBERT config sets "output_hidden_states": "True" and "output_attentions": "True",
        # so the outputs include every layer (returning only the last layer also works)
        all_hidden_states, all_attentions = embeddings[-2:]  # hidden_states and attentions of all layers
        embeddings = all_hidden_states[-2]  # hidden_states of the second-to-last layer

Contributors

jiangnanboy
