SMART - PyTorch

A PyTorch implementation of SMART, a regularization technique for fine-tuning pretrained (language) models. You might also be interested in vat-pytorch, a more generic collection of virtual adversarial training (VAT) methods in PyTorch.

Install

$ pip install smart-pytorch

Usage

Minimal Example

import torch
import torch.nn as nn
from smart_pytorch import SMARTLoss

# Define function that will be perturbed (usually our network)
eval_fn = nn.Linear(in_features=10, out_features=20)

# Define loss function between states 
loss_fn = nn.MSELoss()

# Initialize regularization loss
regularizer = SMARTLoss(eval_fn = eval_fn, loss_fn = loss_fn)

# Compute initial input embed and output state 
embed = torch.rand(1, 10) # [batch_size, in_features]
state = eval_fn(embed) # [batch_size, out_features]

# Compute regularization loss 
loss = regularizer(embed, state)
loss # tensor(0.0922578126, grad_fn=<MseLossBackward0>)

Here eval_fn is a function (usually a neural network) that takes an embedding embed as input and produces one or more output states state. Internally, SMARTLoss perturbs embed with noise, evaluates eval_fn on the perturbed input, and compares the resulting perturbed state against the state provided initially.
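The regularizer roughly follows the adversarial procedure from the SMART paper: sample small noise, take one or more gradient ascent steps on the noise to maximize the discrepancy between perturbed and original states, then penalize that worst-case discrepancy. The following is a minimal sketch of the idea, not the library's exact implementation; in particular, the clamp-based epsilon projection is a simplifying assumption.

import torch

def smart_sketch(eval_fn, loss_fn, embed, state, num_steps=1,
                 step_size=1e-3, epsilon=1e-6, noise_var=1e-5):
    # Start from small Gaussian noise around the input embedding
    noise = torch.randn_like(embed) * noise_var
    for _ in range(num_steps):
        noise.requires_grad_()
        adv_loss = loss_fn(eval_fn(embed + noise), state.detach())
        # Ascend the loss surface to find the most damaging noise
        (grad,) = torch.autograd.grad(adv_loss, noise)
        norm = grad.norm(p=float('inf'), dim=-1, keepdim=True)
        noise = (noise + step_size * grad / (norm + 1e-12)).detach()
        # Keep the perturbation inside an epsilon ball (assumed projection)
        noise = noise.clamp(-epsilon, epsilon)
    # Penalize the model's sensitivity to the worst-case noise found
    return loss_fn(eval_fn(embed + noise), state)

Because the penalty is computed through eval_fn, minimizing it pushes the model to be smooth around each training input, which is the core idea behind SMART's smoothness-inducing regularization.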

Full API Example

import torch
import torch.nn as nn
from smart_pytorch import SMARTLoss

# Define function that will be perturbed (usually our network)
eval_fn = nn.Linear(in_features=10, out_features=20)

# Define loss function between states 
loss_fn = nn.MSELoss()

# Norm used to normalize the gradient 
inf_norm = lambda x: torch.norm(x, p=float('inf'), dim=-1, keepdim=True)

# Initialize regularization loss
regularizer = SMARTLoss(
    eval_fn = eval_fn,      
    loss_fn = loss_fn,      # Loss to apply between perturbed and true state 
    loss_last_fn = loss_fn, # Loss to apply between perturbed and true state on the last iteration (default = loss_fn)
    norm_fn = inf_norm,     # Norm used to normalize the gradient (default = inf_norm)
    num_steps = 1,          # Number of optimization steps to find noise (default = 1)
    step_size = 1e-3,       # Step size to improve noise (default = 1e-3)
    epsilon = 1e-6,         # Noise norm constraint (default = 1e-6)
    noise_var = 1e-5        # Initial noise variance (default = 1e-5)
)

# Compute initial input embed and output state 
embed = torch.rand(1, 10) # [batch_size, in_features]
state = eval_fn(embed) # [batch_size, out_features]

# Compute regularization loss 
loss = regularizer(embed, state)
loss # tensor(0.0432184562, grad_fn=<MseLossBackward0>)
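In practice the regularization loss is added to a supervised task loss during fine-tuning. A minimal sketch under assumed hyperparameters (the weight of 0.02 matches the classification example below, but is a tunable choice):

import torch
import torch.nn as nn
from smart_pytorch import SMARTLoss

model = nn.Linear(10, 20)    # stand-in for a pretrained network
regularizer = SMARTLoss(eval_fn = model, loss_fn = nn.MSELoss())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
weight = 0.02                # regularization strength (assumed hyperparameter)

embed = torch.rand(8, 10)    # a batch of input embeddings
target = torch.rand(8, 20)   # matching regression targets

state = model(embed)
loss = nn.functional.mse_loss(state, target) + weight * regularizer(embed, state)

optimizer.zero_grad()
loss.backward()
optimizer.step()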

RoBERTa Classification Example

This example demonstrates how to wrap a RoBERTa classifier from Hugging Face for use with SMART.

import torch
import torch.nn as nn
import torch.nn.functional as F
from smart_pytorch import SMARTLoss, kl_loss, sym_kl_loss
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class SMARTRobertaClassificationModel(nn.Module):
    
    def __init__(self, model, weight = 0.02):
        super().__init__()
        self.model = model 
        self.weight = weight

    def forward(self, input_ids, attention_mask, labels):

        # Get initial embeddings 
        embed = self.model.roberta.embeddings(input_ids) 

        # Define eval function 
        def eval_fn(embed):
            outputs = self.model.roberta(inputs_embeds=embed, attention_mask=attention_mask)
            hidden = outputs[0] # last hidden state: [batch_size, seq_len, hidden_size]
            logits = self.model.classifier(hidden) # the classification head pools internally
            return logits 
        
        # Define SMART loss
        smart_loss_fn = SMARTLoss(eval_fn = eval_fn, loss_fn = kl_loss, loss_last_fn = sym_kl_loss)
        # Compute initial (unperturbed) state 
        state = eval_fn(embed)
        # Apply classification loss 
        loss = F.cross_entropy(state.view(-1, 2), labels.view(-1)) # 2 = number of classes
        # Apply smart loss 
        loss += self.weight * smart_loss_fn(embed, state)
        
        return state, loss
    

tokenizer = AutoTokenizer.from_pretrained('roberta-base')
model = AutoModelForSequenceClassification.from_pretrained('roberta-base')  

model_smart = SMARTRobertaClassificationModel(model)
# Compute inputs 
text = ["This text belongs to class 1...", "This text belongs to class 0..."]
inputs = tokenizer(text, return_tensors='pt')
labels = torch.tensor([1, 0]) 

# Compute output and loss 
state, loss = model_smart(input_ids = inputs['input_ids'], attention_mask = inputs['attention_mask'], labels = labels)
print(state.shape, loss) # torch.Size([2, 2]) tensor(0.6980957389, grad_fn=<AddBackward0>)
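The returned loss can be backpropagated as usual. A minimal fine-tuning step might look like the following; the optimizer and learning rate are illustrative choices, not prescribed by the library:

optimizer = torch.optim.AdamW(model_smart.parameters(), lr=2e-5)

state, loss = model_smart(
    input_ids = inputs['input_ids'],
    attention_mask = inputs['attention_mask'],
    labels = labels,
)
optimizer.zero_grad()
loss.backward()
optimizer.step()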

Citations

@inproceedings{Jiang2020SMARTRA,
  title={SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization},
  author={Haoming Jiang and Pengcheng He and Weizhu Chen and Xiaodong Liu and Jianfeng Gao and Tuo Zhao},
  booktitle={ACL},
  year={2020}
}
