pytorch-OIE

Open Information Extraction

Requirements

torch
transformers
tqdm
numpy
scikit-learn

Quick Tour

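Supervised OIE is treated as a sequence labeling task [typical annotation format] that extracts n-ary information from natural-language sentences: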

2009年11月,奥巴马 将对 ** 进行国事访问。
---------- ------      ---     -------
   ARG3     ARG0       ARG1     PRED

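(奥巴马,国事访问,**,2009年11月)

English gloss: "In November 2009, Obama will pay a state visit to **." The extracted tuple is (Obama, state visit, **, November 2009), i.e. (ARG0, PRED, ARG1, ARG3).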
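The step from per-token labels to the extracted tuple can be sketched in a few lines of Python. This is only an illustration, not the repository's decoder: the BIO prefixes (`B-`, `I-`, `O`), the hand-made token segmentation, and the helper name `decode_spans` are all assumptions made for this example.

```python
def decode_spans(tokens, tags, joiner=""):
    """Group BIO-tagged tokens into text spans, keyed by label."""
    spans = {}            # label -> list of spans, each span a list of tokens
    current_label = None  # label of the span currently being extended, if any
    for token, tag in zip(tokens, tags):
        if tag == "O":
            current_label = None
            continue
        prefix, label = tag.split("-", 1)
        if prefix == "B" or label != current_label:
            # "B-" (or a stray "I-" with a new label) opens a new span
            spans.setdefault(label, []).append([token])
            current_label = label
        else:
            # "I-" with the same label extends the open span
            spans[label][-1].append(token)
    return {label: [joiner.join(span) for span in parts]
            for label, parts in spans.items()}


# Hand-segmented tokens for the example sentence (an assumption; the model
# itself works on BERT word pieces). The masked "**" is kept as-is.
tokens = ["2009年", "11月", ",", "奥巴马", "将", "对", "**", "进行", "国事", "访问", "。"]
tags = ["B-ARG3", "I-ARG3", "O", "B-ARG0", "O", "O", "B-ARG1", "O", "B-PRED", "I-PRED", "O"]

spans = decode_spans(tokens, tags)
# Order the spans as in the example tuple: (ARG0, PRED, ARG1, ARG3).
extraction = tuple(spans[label][0] for label in ("ARG0", "PRED", "ARG1", "ARG3"))
print(extraction)
```

Each label maps to a list of spans, so a sentence with several spans for the same label is still representable.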

This repository uses Bert-MLP-CLS as the baseline model for information extraction. For other model choices, consult the references first.

Inference

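import torch

from package.model.alpha import AlphaModel, AlphaConfig

# The configuration must match the one used to train the checkpoint.
config = AlphaConfig(pretrained_model_name_or_path='bert-base-multilingual-cased', pos_embedding_dim=64,
                     fc_1_hidden_size=768, fc_1_dropout_rate=0.3, fc_2_hidden_size=768, fc_2_dropout_rate=0.3,
                     pre_tag_size=3, arg_tag_size=9)
model = AlphaModel(config).eval()  # eval() disables dropout
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load('path_to_pth', map_location='cpu'))

text = ['enter your text']
with torch.no_grad():  # no gradients are needed for inference
    out = model(text)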
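The README does not document what `out` contains, so the following is only a sketch of one plausible post-processing step. It assumes that `pre_tag_size=3` corresponds to a BIO tag set for the predicate and `arg_tag_size=9` to a BIO tag set over ARG0 to ARG3, and that the model yields one score per tag for each token. The tag names and the helpers `build_bio_vocab` and `scores_to_tags` are hypothetical and not part of the package.

```python
def build_bio_vocab(labels):
    """Return ["O", "B-x", "I-x", ...] for the given labels (assumed layout)."""
    vocab = ["O"]
    for label in labels:
        vocab += [f"B-{label}", f"I-{label}"]
    return vocab


# Assumed tag sets; their sizes match pre_tag_size=3 and arg_tag_size=9.
PRE_TAGS = build_bio_vocab(["PRED"])
ARG_TAGS = build_bio_vocab(["ARG0", "ARG1", "ARG2", "ARG3"])


def scores_to_tags(scores, vocab):
    """Pick the highest-scoring tag for every token (plain argmax decoding)."""
    tags = []
    for row in scores:
        best = max(range(len(row)), key=row.__getitem__)
        tags.append(vocab[best])
    return tags


# Made-up scores for three tokens over the three predicate tags.
example_scores = [[0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.9, 0.05, 0.05]]
print(scores_to_tags(example_scores, PRE_TAGS))
```

If the real output is a tensor of logits, the same idea applies with an argmax over the last dimension; check the model's `forward` method for the actual return type and tag order.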

Training

The training scripts are named main_xxxx.py (for example, [main_alpha.py]). A typical script looks like this:

import torch

from package.dataset.dense import Dataset, DataLoader
from package.model.beta import BetaModel, BetaConfig
from tqdm import tqdm

# hyperparameters
lr = 5e-3
num_workers = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = BetaConfig(pretrained_model_name_or_path='bert-base-multilingual-cased', pos_embedding_dim=64,
                    fc_1_hidden_size=768, fc_1_dropout_rate=0.2,
                    fc_2_hidden_size=768, fc_2_dropout_rate=0.2,
                    pre_tag_size=3, arg_tag_size=9
                    )
model = BetaModel(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

dataset = Dataset('./resource/OIE2016/train.oie.json', model.tokenizer)
dataloader = DataLoader(dataset, batch_size=128, collate_fn=dataset.collate_fn, shuffle=True, num_workers=num_workers)

# training
for i in range(200):
    with tqdm(total=len(dataloader), desc=f'Epoch {i}: Training ...') as t:
        for batch in dataloader:
            optimizer.zero_grad()

            input_ids, mask, pre_label_all, pre_label, arg_label = [x.to(device) for x in batch]
            try:
                loss_pre, loss_arg = model.loss(input_ids, mask, pre_label_all, pre_label, arg_label)
            except Exception:
                # skip noisy data, but still advance the progress bar;
                # catching Exception (not a bare except) keeps Ctrl-C working
                t.update()
                continue
            loss = loss_pre + loss_arg

            loss.backward()
            optimizer.step()

            t.update()
            t.set_postfix(loss=loss.item(), loss_pre=loss_pre.item(), loss_arg=loss_arg.item())

    # save a checkpoint every 25 epochs (i = 24, 49, ..., 199)
    if (i + 1) % 25 == 0:
        torch.save(model.state_dict(), f"model_{i}.pth")
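Because the epoch counter starts at 0 and the check is `(i + 1) % 25 == 0`, the checkpoint file names are offset by one from the number of completed epochs. The short sketch below lists the files that the loop is expected to write. The helper `saved_checkpoints` is hypothetical and only mirrors the save condition.

```python
def saved_checkpoints(num_epochs=200, save_every=25):
    """List the file names written by the save condition in the training loop."""
    return [f"model_{i}.pth" for i in range(num_epochs) if (i + 1) % save_every == 0]


checkpoints = saved_checkpoints()
print(checkpoints)
# ['model_24.pth', 'model_49.pth', 'model_74.pth', 'model_99.pth',
#  'model_124.pth', 'model_149.pth', 'model_174.pth', 'model_199.pth']
```

The last entry, `model_199.pth`, is the state after all 200 epochs and is a natural choice for `path_to_pth` in the inference snippet.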

Reference

  • Supervised Open Information Extraction. NAACL-HLT. 2018. [paper]
  • Logician: A Unified End-to-End Neural Approach for Open-Domain Information Extraction. WSDM. 2018. [paper] [zh-dataset-SAOKE]
  • Span Model for Open Information Extraction on Accurate Corpus. AAAI. 2020. [paper] [github]
  • Multi^2OIE: Multilingual Open Information Extraction Based on Multi-Head Attention with BERT. EMNLP. 2020. [paper] [github]

Contributors

  • yuto3o
