
PyTorch solution of NER task with Google AI's BERT model

0. Introduction

This repository contains a solution to the NER task based on a PyTorch reimplementation of Google's TensorFlow repository for the BERT model, which was released together with the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.

This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular Google's pre-trained models) and a conversion script is provided (see below).

1. Loading a TensorFlow checkpoint (e.g. Google's pre-trained models)

You can convert any TensorFlow checkpoint for BERT (in particular the pre-trained models released by Google) into a PyTorch save file by using the convert_tf_checkpoint_to_pytorch.py script.

This script takes as input a TensorFlow checkpoint (three files starting with bert_model.ckpt) and the associated configuration file (bert_config.json), creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint into the PyTorch model, and saves the resulting model in a standard PyTorch save file that can be imported using torch.load().

You only need to run this conversion script once to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with bert_model.ckpt) but be sure to keep the configuration file (bert_config.json) and the vocabulary file (vocab.txt) as these are needed for the PyTorch model too.
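As a minimal sketch of loading the result (assuming the pytorch_model.bin dump path used in the example below), the converted weights are an ordinary PyTorch save file:

```python
import torch

# Load the converted weights produced by the conversion script;
# map_location="cpu" avoids requiring a GPU at load time.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
```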

To run this specific conversion script you will need to have TensorFlow and PyTorch installed (pip install tensorflow). The rest of the repository only requires PyTorch.

Here is an example of the conversion process for a pre-trained BERT-Base Multilingual model:

export BERT_BASE_DIR=/path/to/bert/multilingual_L-12_H-768_A-12

python3 convert_tf_checkpoint_to_pytorch.py \
    --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
    --bert_config_file $BERT_BASE_DIR/bert_config.json \
    --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin

You can download Google's pre-trained models for the conversion from the google-research/bert repository (https://github.com/google-research/bert).

This solution uses the BERT-Base, Multilingual and the BERT-Base, Multilingual Cased (recommended) models.

2. Results

We did not search for the best parameters and obtained the following results after training for no more than 10 epochs.

Only NER models

Model: BertBiLSTMAttnCRF.

| Dataset    | Lang | IOB precision | Span precision | Total spans in test set | Notebook         |
|------------|------|---------------|----------------|-------------------------|------------------|
| FactRuEval | ru   | 0.937         | 0.883          | 4                       | factrueval.ipynb |
| Atis       | en   | 0.852         | 0.787          | 65                      | atis.ipynb       |
| CoNLL-2003 | en   | 0.945         | 0.858          | 5                       | conll-2003.ipynb |
  • FactRuEval (f1): 0.9163±0.006, best 0.926.
  • Atis (f1): 0.882±0.02, best 0.896.
  • CoNLL-2003 (f1, dev): 0.949±0.002, best 0.951; 0.892 (f1, test).

Model: BertBiLSTMAttnNMT.

| Dataset    | Lang | IOB precision | Span precision | Total spans in test set | Notebook             |
|------------|------|---------------|----------------|-------------------------|----------------------|
| FactRuEval | ru   | 0.925         | 0.827          | 4                       | factrueval-nmt.ipynb |
| Atis       | en   | 0.919         | 0.829          | 65                      | atis-nmt.ipynb       |
| CoNLL-2003 | en   | 0.936         | 0.900          | 5                       | conll-2003-nmt.ipynb |

Joint Models

Model: BertBiLSTMAttnCRFJoint

| Dataset | Lang | IOB precision | Span precision | Clf precision | Total spans in test set | Total classes | Notebook         |
|---------|------|---------------|----------------|---------------|-------------------------|---------------|------------------|
| Atis    | en   | 0.877         | 0.824          | 0.894         | 65                      | 17            | atis-joint.ipynb |

Model: BertBiLSTMAttnNMTJoint

| Dataset | Lang | IOB precision | Span precision | Clf precision | Total spans in test set | Total classes | Notebook             |
|---------|------|---------------|----------------|---------------|-------------------------|---------------|----------------------|
| Atis    | en   | 0.913         | 0.820          | 0.888         | 65                      | 17            | atis-joint-nmt.ipynb |

Comparison with the ELMo model

We tested BertBiLSTMCRF, BertBiLSTMAttnCRF and BertBiLSTMAttnNMT on the Russian dataset FactRuEval with a frozen ElmoEmbedder:

Model BertBiLSTMCRF:

| Dataset    | Lang | IOB precision | Span precision | Total spans in test set | Notebook      |
|------------|------|---------------|----------------|-------------------------|---------------|
| FactRuEval | ru   | 0.903         | 0.851          | 4                       | samples.ipynb |

Model BertBiLSTMAttnCRF:

| Dataset    | Lang | IOB precision | Span precision | Total spans in test set | Notebook         |
|------------|------|---------------|----------------|-------------------------|------------------|
| FactRuEval | ru   | 0.899         | 0.819          | 4                       | factrueval.ipynb |

Model BertBiLSTMAttnNMT:

| Dataset    | Lang | IOB precision | Span precision | Total spans in test set | Notebook             |
|------------|------|---------------|----------------|-------------------------|----------------------|
| FactRuEval | ru   | 0.902         | 0.752          | 4                       | factrueval-nmt.ipynb |

3. Installation, requirements, test

This code was tested on Python 3.5+. The requirements are:

  • PyTorch (>= 0.4.1)
  • tqdm
  • tensorflow (for conversion)

To install the dependencies:

pip install -r ./requirements.txt

PyTorch neural network models

All models are organized as Encoder-Decoder. The encoder is a frozen BERT model whose 12 layer outputs are combined with learned weights (as proposed in ELMo); a minimal sketch of this layer mix follows the model list below. The models are obtained by using different decoders.

Encoder: BertBiLSTM

  1. BertBiLSTMCRF: Encoder + Decoder (BiLSTM + CRF)
  2. BertBiLSTMAttnCRF: Encoder + Decoder (BiLSTM + MultiHead Attention + CRF)
  3. BertBiLSTMAttnNMT: Encoder + Decoder (LSTM + Bahdanau Attention - NMT Decode)
  4. BertBiLSTMAttnCRFJoint: Encoder + Decoder (BiLSTM + MultiHead Attention + CRF) + (PoolingLinearClassifier - for classification) - joint model with classification.
  5. BertBiLSTMAttnNMTJoint: Encoder + Decoder (LSTM + Bahdanau Attention - NMT Decode) + (LinearClassifier - for classification) - joint model with classification.
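
For intuition, here is a minimal sketch of an ELMo-style weighted mix over frozen BERT layer outputs, as used by the encoder; the ScalarMix class name and its wiring are illustrative, not the repository's actual code:

```python
import torch
import torch.nn as nn

class ScalarMix(nn.Module):
    # Illustrative ELMo-style mix: a softmax-weighted sum over frozen BERT layers.
    def __init__(self, num_layers=12):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(num_layers))  # one scalar per layer
        self.gamma = nn.Parameter(torch.ones(1))              # global scale, as in ELMo

    def forward(self, layer_outputs):
        # layer_outputs: list of [batch, seq_len, hidden] tensors, one per BERT layer.
        norm_weights = torch.softmax(self.weights, dim=0)
        stacked = torch.stack(layer_outputs, dim=0)           # [num_layers, B, T, H]
        return self.gamma * (norm_weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)
```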

Usage

1. Loading data:

from modules.bert_data import BertNerData as NerData

data = NerData.create(train_path, valid_path, vocab_file)

2. Create model:

from modules.bert_models import BertBiLSTMCRF

model = BertBiLSTMCRF.create(len(data.label2idx), bert_config_file, init_checkpoint_pt, enc_hidden_dim=256)

3. Create learner:

from modules.train import NerLearner

learner = NerLearner(model, data,
                     best_model_path="/datadrive/models/factrueval/exp_final.cpt",
                     lr=0.01, clip=1.0,
                     sup_labels=data.id2label[5:],
                     t_total=num_epochs * len(data.train_dl))

4. Train your NER model:

learner.fit(2, target_metric='prec')

5. Predict on new data:

from modules.data.bert_data import get_bert_data_loader_for_predict

dl = get_bert_data_loader_for_predict(data_path + "valid.csv", learner)

learner.load_model(best_model_path)

preds = learner.predict(dl)
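
The exact structure of preds depends on the decoder; as a hedged illustration, assuming a list of per-sentence label-index sequences, the indices can be mapped back to tag strings with the data.id2label mapping used above:

```python
# Hypothetical post-processing: convert predicted label indices to tag strings.
pred_tags = [[data.id2label[idx] for idx in sent] for sent in preds]
```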

  • For more detailed instructions on using the BERT model, see samples.ipynb.
  • For more detailed instructions on using the ELMo model, see samples.ipynb.
