
ratransformers's Introduction

Joao Lages' LinkedIn profile · Joao Lages' Medium articles

Hi there 👋

ratransformers's People

Contributors

joaolages


ratransformers's Issues

Errors when using multiple GPUs

Hello, I was trying to train the model on two GPUs using your example code but got the following error. Could you please point me to the possible locations that caused it? It looks like the input to the model was moved to the second GPU somewhere. In the Trainer, the model and data are placed on the first GPU by default. Thank you!

Traceback (most recent call last):
  File "/home/python3.8/site-packages/transformers/trainer.py", line 2345, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/python3.8/site-packages/transformers/trainer.py", line 2377, in compute_loss
    outputs = model(**inputs)
  File "/home/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/python3.8/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/python3.8/site-packages/ratransformers/__init__.py", line 193, in run
    return function(*args, **kwargs)
  File "/home/python3.8/site-packages/ratransformers/t5.py", line 114, in forward
    encoder_outputs = self.encoder(
  File "/home/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/python3.8/site-packages/ratransformers/t5.py", line 422, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/home/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/python3.8/site-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/home/python3.8/site-packages/torch/nn/functional.py", line 2199, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper__index_select)

The example code I'm using is the following:

import json
from transformers import AutoModelForSeq2SeqLM
import ratransformers
import torch
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, EarlyStoppingCallback
from collections import defaultdict


class Text2SQLDataset(torch.utils.data.Dataset):
    def __init__(self, X, y, tokenizer, X_word_relations=None):
        self.X = X
        self.y = y
        self.X_word_relations = X_word_relations or [None] * len(X)
        self.tokenizer = tokenizer

    def __getitem__(self, index: int) -> dict:
        source = self.tokenizer(self.X[index], padding='max_length', input_relations=self.X_word_relations[index],
                                return_tensors="pt")
        target = self.tokenizer(self.y[index], padding='max_length', input_relations=None, return_tensors="pt")

        source_ids = source["input_ids"].squeeze()
        source_input_relations = source["input_relations"].squeeze()
        target_ids = target["input_ids"].squeeze()
        target_ids[target_ids == 0] = -100

        src_mask = source["attention_mask"].squeeze()
        target_mask = target["attention_mask"].squeeze()

        return {
            "input_ids": source_ids.to('cuda:0'),
            "attention_mask": src_mask.to('cuda:0'),
            "label": target_ids.to('cuda:0'),
            "decoder_attention_mask": target_mask.to('cuda:0'),
            'input_relations': source_input_relations.to('cuda:0')
        }

    def __len__(self):
        return len(self.X)


def get_processed_data(raw_data):
    X, y, X_word_relations = [], [], []
    n_skip = 0
    for d in raw_data:
        input_text = d['question'] + f" | {d['db_id']}"

        word_relations = defaultdict(dict)

        table_span, table_i = None, None
        for i, c_name in tables[d['db_id']]['column_names_original']:
            if i < 0: continue
            if table_i != i:
                table_i = i
                table_span = (
                    len(input_text + ' | '),
                    len(input_text + ' | ') + len(tables[d['db_id']]['table_names_original'][i]))
                input_text += f" | {tables[d['db_id']]['table_names_original'][i]} : "

                c_span = (len(input_text), len(input_text) + len(c_name))
                input_text += c_name

            else:
                c_span = (len(input_text + ', '), len(input_text + ', ') + len(c_name))
                input_text += f', {c_name}'

            word_relations[table_span][c_span] = 'table_column_link'
            word_relations[c_span][table_span] = 'column_table_link'

        if len(input_text.split()) > 200:
            # Skipped sample with too long input
            n_skip += 1
            continue

        X.append(input_text.lower())
        y.append((d['db_id'] + ' | ' + d['query']).lower())
        X_word_relations.append(word_relations)

    return X, y, X_word_relations, n_skip


with open('spider/tables.json') as fp:
    tables = {t['db_id']: t for t in json.load(fp)}

with open('spider/train_spider.json') as fp:
    train_data = json.load(fp)

with open('spider/train_others.json') as fp:
    train_data += json.load(fp)

with open('spider/dev.json') as fp:
    test_data = json.load(fp)

train_X, train_y, train_X_word_relations, n_skip = get_processed_data(train_data)
print("Train:", len(train_X), f" Skipped {n_skip} samples with too long input.")
test_X, test_y, test_X_word_relations, n_skip = get_processed_data(test_data)
print("Test:", len(test_X), f" Skipped {n_skip} samples with too long input.")

ratransformer = ratransformers.RATransformer(
    'tscholak/1zha5ono',
    relation_kinds=['table_column_link', 'column_table_link'],
    model_cls=AutoModelForSeq2SeqLM
)
model = ratransformer.model
tokenizer = ratransformer.tokenizer
# Get datasets with word relations
train_d = Text2SQLDataset(train_X, train_y, tokenizer, train_X_word_relations)
val_d = Text2SQLDataset(test_X, test_y, tokenizer, test_X_word_relations)

# Get datasets without word relations
train_d_without_relations = Text2SQLDataset(train_X, train_y, tokenizer)
val_d_without_relations = Text2SQLDataset(test_X, test_y, tokenizer)

# Set training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir='checkpoints',
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=4,
    max_steps=100000,
    eval_steps=1000,
    seed=42,
    save_total_limit=1,
    predict_with_generate=True,
)

# Set trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_d,
    eval_dataset=val_d,
    tokenizer=tokenizer,
)

# # get performance before training
# trainer.evaluate()

# train until early stopping
trainer.train()

# get performance after training
trainer.evaluate()

# Save model
trainer.save_model('ra-tscholak/1zha5ono')

# Reload model again
ratransformer = ratransformers.RATransformer(
    'ra-tscholak/1zha5ono',
    relation_kinds=['table_column_link', 'column_table_link'],
    alias_model_name='t5'
)
model = ratransformer.model
tokenizer = ratransformer.tokenizer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_d,
    eval_dataset=val_d,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback()]
)
trainer.evaluate()
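
For reference, a minimal sketch of one possible workaround, assuming the device mismatch comes from the explicit .to('cuda:0') calls in __getitem__: return plain CPU tensors from the dataset and let the Trainer (and torch.nn.DataParallel) move each batch to the appropriate replica's device. The class below only mirrors the example above for illustration; it also uses the standard "labels" key that Hugging Face seq2seq models expect, rather than "label".

import torch


class Text2SQLDatasetCPU(torch.utils.data.Dataset):
    """Variant of the dataset above that keeps all tensors on CPU.

    The Trainer moves each batch to the model's device (and DataParallel
    scatters it across replicas), so no manual .to('cuda:0') is needed.
    """

    def __init__(self, X, y, tokenizer, X_word_relations=None):
        self.X = X
        self.y = y
        self.X_word_relations = X_word_relations or [None] * len(X)
        self.tokenizer = tokenizer

    def __getitem__(self, index: int) -> dict:
        source = self.tokenizer(self.X[index], padding='max_length',
                                input_relations=self.X_word_relations[index],
                                return_tensors="pt")
        target = self.tokenizer(self.y[index], padding='max_length',
                                input_relations=None, return_tensors="pt")

        target_ids = target["input_ids"].squeeze()
        target_ids[target_ids == 0] = -100  # ignore pad tokens in the loss

        # Return CPU tensors; device placement is left to the Trainer.
        return {
            "input_ids": source["input_ids"].squeeze(),
            "attention_mask": source["attention_mask"].squeeze(),
            "labels": target_ids,
            "decoder_attention_mask": target["attention_mask"].squeeze(),
            "input_relations": source["input_relations"].squeeze(),
        }

    def __len__(self):
        return len(self.X)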
