ernie

BERT's best friend.

Installation

Ernie requires Python 3.6 or higher.

pip install ernie

Fine-Tuning

Sentence Classification

from ernie import SentenceClassifier, Models
import pandas as pd

tuples = [
    ("This is a positive example. I'm very happy today.", 1),
    ("This is a negative sentence. Everything was wrong today at work.", 0)
]
df = pd.DataFrame(tuples)

classifier = SentenceClassifier(
    model_name=Models.BertBaseUncased,
    max_length=64,
    labels_no=2
)
classifier.load_dataset(df, validation_split=0.2)
classifier.fine_tune(
    epochs=4,
    learning_rate=2e-5,
    training_batch_size=32,
    validation_batch_size=64
)
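
The validation_split argument holds out a fraction of the rows for validation. The sketch below only illustrates that idea in plain Python. split_dataset is a hypothetical helper, and ernie's own split (whether it shuffles, how it rounds) may differ.

```python
import random


def split_dataset(rows, validation_split, seed=0):
    """Hypothetical helper: shuffle the rows and hold out a fraction for validation."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    # Round the validation size down, but always keep at least one training row
    n_validation = min(int(len(rows) * validation_split), max(len(rows) - 1, 0))
    return rows[n_validation:], rows[:n_validation]


tuples = [
    ("This is a positive example. I'm very happy today.", 1),
    ("This is a negative sentence. Everything was wrong today at work.", 0),
    ("Great service and friendly staff.", 1),
    ("The package arrived broken.", 0),
    ("I would definitely buy it again.", 1),
]
train, validation = split_dataset(tuples, validation_split=0.2)
print(len(train), len(validation))  # 4 1
```

In this sketch, the two-row dataset from the example above would leave no whole row for validation, so a real dataset needs considerably more examples.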

Prediction

Predict a single text

text = "Oh, that's great!"

# Returns a tuple with the probability of each class
probabilities = classifier.predict_one(text)
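
Since the result holds one probability per class, the predicted label is the index of the largest value. A minimal sketch, assuming the tuple is ordered by label index (0, 1, ...) as in the labels used for fine-tuning; predicted_label is a hypothetical helper and the probabilities are made-up values, not real model output.

```python
def predicted_label(probabilities):
    """Hypothetical helper: return the index of the highest probability."""
    return max(range(len(probabilities)), key=lambda i: probabilities[i])


probabilities = (0.08, 0.92)  # illustrative values, not real model output
print(predicted_label(probabilities))  # 1
```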

Predict multiple texts

texts = ["Oh, that's great!", "That's really bad"]

# Returns a generator that yields one tuple of class probabilities per text
probabilities = classifier.predict(texts)
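
Because a generator produces its results lazily and can only be iterated once, it is convenient to pair each text with its prediction as you consume it. The sketch below uses fake_predict, a hypothetical stand-in for classifier.predict with made-up scores, so it runs without a model.

```python
def fake_predict(texts):
    """Hypothetical stand-in for classifier.predict: yields one tuple per text."""
    for text in texts:
        positive = 0.9 if "great" in text else 0.1  # made-up scores
        yield (1 - positive, positive)


texts = ["Oh, that's great!", "That's really bad"]
predictions = fake_predict(texts)

# zip pulls one prediction per text, keeping texts and results aligned
results = list(zip(texts, predictions))
for text, probabilities in results:
    print(text, probabilities)

# No predictions are left: iterating the generator again yields nothing
print(list(predictions))  # []
```

If you need the predictions more than once, store them in a list first, as done with results here.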

Prediction Strategies

Texts whose length in tokens exceeds the max_length used during fine-tuning are truncated. To avoid losing information, you can use a split strategy to divide each text into shorter parts and an aggregation strategy to combine the predictions of those parts.

Split Strategies

  • SentencesWithoutUrls. The text is split into sentences.
  • GroupedSentencesWithoutUrls. The text is split into groups of sentences whose length in tokens is close to max_length.
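
A rough sketch of what these two split strategies do, under simplifying assumptions: URLs are removed with a simple regular expression, sentences end at `.`, `!` or `?`, and tokens are approximated by whitespace-separated words. ernie's real strategies use their own patterns and count tokens with the model's tokenizer; split_sentences and group_sentences are hypothetical helpers.

```python
import re

URL_PATTERN = re.compile(r"https?://\S+")  # simplified URL pattern (assumption)
SENTENCE_PATTERN = re.compile(r"[^.!?]+[.!?]*")


def split_sentences(text):
    """Remove URLs, then split the remaining text into sentences."""
    text = URL_PATTERN.sub("", text)
    pieces = SENTENCE_PATTERN.findall(text)
    # Collapse the whitespace left behind (e.g. where a URL was removed)
    return [" ".join(piece.split()) for piece in pieces if piece.strip()]


def group_sentences(sentences, max_length):
    """Greedily pack consecutive sentences into groups of at most max_length words."""
    groups, current, current_length = [], [], 0
    for sentence in sentences:
        length = len(sentence.split())
        if current and current_length + length > max_length:
            groups.append(" ".join(current))
            current, current_length = [], 0
        current.append(sentence)
        current_length += length
    if current:
        groups.append(" ".join(current))
    return groups


text = "I loved it. See https://example.com for details! Would buy again? Yes."
sentences = split_sentences(text)
print(sentences)
print(group_sentences(sentences, max_length=6))
```

In this sketch a single sentence longer than max_length still forms its own group, so it would be truncated at prediction time.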

Aggregation Strategies

  • Mean: the prediction of the text is the mean of the predictions of its splits.
  • MeanTopFiveBinaryClassification: the mean is computed over the 5 highest predictions only.
  • MeanTopTenBinaryClassification: the mean is computed over the 10 highest predictions only.
  • MeanTopFifteenBinaryClassification: the mean is computed over the 15 highest predictions only.
  • MeanTopTwentyBinaryClassification: the mean is computed over the 20 highest predictions only.

from ernie import SplitStrategies, AggregationStrategies

texts = ["Oh, that's great!", "That's really bad"]
probabilities = classifier.predict(
    texts,
    split_strategy=SplitStrategies.GroupedSentencesWithoutUrls,
    aggregation_strategy=AggregationStrategies.Mean
) 
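
The aggregation strategies listed above can be sketched in plain Python. This assumes each split yields a (negative, positive) probability tuple and that the "highest" predictions are the splits with the largest probability for class 1; aggregate is a hypothetical helper, not ernie's implementation.

```python
from statistics import mean


def aggregate(split_predictions, max_items=None, sorting_class_index=1):
    """Average per-split probability tuples, optionally over the top max_items only.

    Splits are ranked by their probability for sorting_class_index
    (assumed to be 1, the positive class) before the top ones are kept.
    """
    predictions = list(split_predictions)
    if max_items is not None:
        predictions = sorted(
            predictions, key=lambda p: p[sorting_class_index], reverse=True
        )[:max_items]
    # Average each class independently across the kept splits
    return tuple(mean(column) for column in zip(*predictions))


splits = [(0.9, 0.1), (0.2, 0.8), (0.4, 0.6), (0.7, 0.3)]
print(aggregate(splits))               # mean over all four splits
print(aggregate(splits, max_items=2))  # mean over the two most positive splits
```

Keeping only the top predictions lets a few strongly positive splits dominate a long text, which the plain mean would dilute.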

You can define your own custom strategies through the AggregationStrategy and SplitStrategy classes.

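import numpy as np
from ernie import SplitStrategy, AggregationStrategy

# Placeholder values: replace them with your own patterns
split_patterns = []   # list: patterns used to split the text
remove_patterns = []  # list: patterns removed from the text

my_split_strategy = SplitStrategy(
    split_patterns=split_patterns,
    remove_patterns=remove_patterns,
    remove_too_short_groups=True,  # bool
    group_splits=True              # bool
)
my_aggregation_strategy = AggregationStrategy(
    method=np.mean,        # function used to aggregate the split predictions
    max_items=5,           # int
    top_items=True,        # bool
    sorting_class_index=1  # int
)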
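
The method parameter takes a function. As an example of a function you might supply instead of a plain mean, here is a trimmed mean that ignores the most extreme split predictions. It assumes the function receives a sequence of per-split probabilities; check ernie's source for the exact input it passes before relying on this. trimmed_mean is a hypothetical helper.

```python
from statistics import mean


def trimmed_mean(values, trim=1):
    """Hypothetical aggregation: drop the `trim` lowest and highest values, then average."""
    values = sorted(values)
    # Only trim when enough values remain afterwards
    if len(values) > 2 * trim:
        values = values[trim:len(values) - trim]
    return mean(values)


print(trimmed_mean([0.1, 0.5, 0.6, 0.99]))  # averages 0.5 and 0.6 only
```

Under that assumption, it could be passed as AggregationStrategy(method=trimmed_mean, ...).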

Save and restore a fine-tuned model

Save model

classifier.dump('./model')

Load model

classifier = SentenceClassifier(model_path='./model')

Interrupted Training

Since execution may be interrupted during training (especially on Google Colab), you can save the model after every epoch so that training can be resumed without losing all progress.

classifier = SentenceClassifier(
    model_name=Models.BertBaseUncased,
    max_length=64
)
classifier.load_dataset(df, validation_split=0.2)

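for epoch in range(1, 5):
    # Simulate a crash before the third epoch
    if epoch == 3:
        raise Exception("Forced crash")

    classifier.fine_tune(epochs=1)
    # Save a checkpoint after every epoch
    classifier.dump(f'./my-model/{epoch}')

# After the crash, resume from the last saved checkpoint
last_training_epoch = 2

classifier = SentenceClassifier(model_path=f'./my-model/{last_training_epoch}')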
classifier.load_dataset(df, validation_split=0.2)

for epoch in range(last_training_epoch + 1, 5):
    classifier.fine_tune(epochs=1)
    classifier.dump(f'./my-model/{epoch}')
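
Instead of hard-coding last_training_epoch, you could derive it from the checkpoint directories written by the loop. A minimal sketch, assuming each epoch is saved to ./my-model/<epoch> as in the example; find_last_epoch is a hypothetical helper.

```python
from pathlib import Path


def find_last_epoch(checkpoint_dir):
    """Return the highest epoch number saved under checkpoint_dir, or 0 if none."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return 0
    # Only numeric directory names are treated as epoch checkpoints
    epochs = [int(p.name) for p in root.iterdir() if p.is_dir() and p.name.isdigit()]
    return max(epochs, default=0)


last_training_epoch = find_last_epoch("./my-model")
print(last_training_epoch)
```

A result of 0 means no checkpoint exists, in which case training would start from the base model rather than from model_path.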

Autosave

Even if you do not explicitly dump the model, it will be autosaved into ./ernie-autosave every time fine_tune is successfully executed.

ernie-autosave/
└── model_family/
    └── timestamp/
        ├── config.json
        ├── special_tokens_map.json
        ├── tf_model.h5
        ├── tokenizer_config.json
        └── vocab.txt

You can easily clean the autosaved models by invoking clean_autosave after finishing a session or when starting a new one.

from ernie import clean_autosave
clean_autosave()
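
Before cleaning, you may want to locate the most recent autosaved model, for example to restore it with SentenceClassifier(model_path=...). The exact timestamp format is an assumption here, so this sketch picks the most recently modified <model_family>/<timestamp> directory by filesystem modification time instead of parsing names; latest_autosave is a hypothetical helper.

```python
from pathlib import Path


def latest_autosave(root="./ernie-autosave"):
    """Return the most recently modified <model_family>/<timestamp> directory, or None."""
    root = Path(root)
    if not root.is_dir():
        return None
    candidates = [p for p in root.glob("*/*") if p.is_dir()]
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)


print(latest_autosave())
```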

Supported Models

You can access some of the official base model names through the Models class. Alternatively, you can pass any Hugging Face model name, such as bert-base-uncased or bert-base-chinese, directly when instantiating a SentenceClassifier.

See all the available models at huggingface.co/models.

BERT

  • BertBaseUncased
  • BertBaseCased
  • BertLargeUncased
  • BertLargeCased

RoBERTa

  • RobertaBaseCased
  • RobertaLargeCased

XLNet

  • XLNetBaseCased
  • XLNetLargeCased

DistilBERT

  • DistilBertBaseUncased
  • DistilBertBaseMultilingualCased

ALBERT

  • AlbertBaseCased
  • AlbertLargeCased
  • AlbertXLargeCased
  • AlbertXXLargeCased
  • AlbertBaseCased2
  • AlbertLargeCased2
  • AlbertXLargeCased2
  • AlbertXXLargeCased2

Additional Info

Accessing the model and tokenizer

You can directly access both the model and tokenizer objects once the classifier has been instantiated:

classifier.model
classifier.tokenizer

Keras model.fit arguments

You can pass arguments of the Keras model.fit method to classifier.fine_tune. For example:

classifier.fine_tune(class_weight={0: 0.2, 1: 0.8})
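
One common way to choose class_weight for an imbalanced dataset is the "balanced" heuristic (the same formula scikit-learn uses): weight = n_samples / (n_classes * count_of_class). A minimal sketch computing it from a list of labels; balanced_class_weights is a hypothetical helper, and other weighting schemes are equally valid.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class inversely to its frequency: n_samples / (n_classes * count)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {label: n_samples / (n_classes * count) for label, count in sorted(counts.items())}


labels = [0, 0, 0, 1]
print(balanced_class_weights(labels))  # {0: 0.6666666666666666, 1: 2.0}
```

With the DataFrame from the fine-tuning example, the labels are in column 1, so the result could be passed as classifier.fine_tune(class_weight=balanced_class_weights(df[1].tolist())).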

Sponsors

Sticker Mule

Contributors

brunneis, marcosfp97
