
Comments (9)

iliasmiraoui avatar iliasmiraoui commented on May 19, 2024 4

It works like a charm, thank you so much!!!

iliasmiraoui avatar iliasmiraoui commented on May 19, 2024 4

You are the absolute best for solving this so quickly and providing an example so fast, thank you so much. I just tested it with XXL and it's training very well!

younesbelkada avatar younesbelkada commented on May 19, 2024 2

Hi @iliasmiraoui

Can you try to replace:

model.decoder.project_in = lambda x: x.requires_grad_(True)

With

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.shared.register_forward_hook(make_inputs_require_grad)

This will be properly addressed to make it more user-friendly in: huggingface/transformers#21598
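
For context, here is a minimal sketch of that hook in place on a T5-style model, together with a quick check that the embedding output now requires grad (the checkpoint name and the check itself are illustrative, not from this issue):

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")  # illustrative checkpoint

def make_inputs_require_grad(module, input, output):
    # Flag the embedding output as requiring grad so that gradient checkpointing
    # has a grad-requiring input even though the embedding weights stay frozen.
    output.requires_grad_(True)

model.shared.register_forward_hook(make_inputs_require_grad)

token_ids = torch.tensor([[0, 1, 2]])
print(model.shared(token_ids).requires_grad)  # True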

pacman100 avatar pacman100 commented on May 19, 2024 1

Hello @iliasmiraoui, thanks to @younesbelkada we were able to reproduce and debug the issue. For T5 models that were trained in BF16, the following changes are required, because load_in_8bit keeps certain layers in FP32.

  1. model.enable_input_require_grads() as mentioned above for proper functioning of gradient checkpointing
  2. diff snippet for lm_head casting:
  class CastOutputToFloat(nn.Sequential):
-   def forward(self, x): return super().forward(x).to(torch.float32)
+   def forward(self, x): return super().forward(x.to(torch.float16)).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)
  3. Disable autocast in the train loop. When using the Trainer, don't pass any mixed-precision argument such as fp16/bf16 in the training arguments.
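
Putting the three points above together, here is a minimal combined sketch (the checkpoint name and the training hyperparameters are illustrative, not the exact Colab code; the full, tested version is in the Colab linked below):

import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, TrainingArguments

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-xxl",   # illustrative BF16-trained T5 checkpoint
    load_in_8bit=True,
    device_map="auto",
)

# 1. Make the embedding outputs require grad so gradient checkpointing works.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

# 2. Same lm_head wrapper as in the diff above: cast the input to fp16 before
#    the head and the logits back to fp32 afterwards.
class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x.to(torch.float16)).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)

# 3. No autocast: leave fp16/bf16 unset in the training arguments.
training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    # fp16=... / bf16=... intentionally not set
)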

Here is a working Colab example: https://colab.research.google.com/drive/1M1hFb5Rr_MSDKByHRqMP_9AZQ5TKj--H?usp=sharing

Let us know if that resolves the issue.

younesbelkada avatar younesbelkada commented on May 19, 2024

I can confirm this worked for flan-t5-base:

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base", 
    load_in_8bit=True, 
    device_map='auto',
)

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

for param in model.parameters():
  param.requires_grad = False  # freeze the model - train adapters later
  if param.ndim == 1:
    # cast the small parameters (e.g. layernorm) to fp32 for stability
    param.data = param.data.to(torch.float32)

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)


model.gradient_checkpointing_enable()  # reduce number of stored activations
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

class CastOutputToFloat(nn.Sequential):
  def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)

import transformers
from datasets import load_dataset
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)

trainer = transformers.Trainer(
    model=model, 
    train_dataset=data['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4, 
        gradient_accumulation_steps=4,
        warmup_steps=100, 
        max_steps=200, 
        learning_rate=2e-4, 
        fp16=True,
        logging_steps=1, 
        output_dir='outputs'
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

iliasmiraoui avatar iliasmiraoui commented on May 19, 2024

@younesbelkada Surprisingly, it doesn't seem like the model is learning / the loss isn't converging (compared to the same loop in float32).

pacman100 avatar pacman100 commented on May 19, 2024

Hello @iliasmiraoui, could you provide a minimal reproducible script that we can run?

younesbelkada avatar younesbelkada commented on May 19, 2024

@iliasmiraoui Now you should use model.enable_input_require_grads(); I have updated the Colab notebook accordingly.
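
For anyone following along, the change amounts to replacing the manual forward hook with the built-in helper (a minimal sketch; the checkpoint name is illustrative and the rest of the setup stays as in the script above):

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base",   # illustrative checkpoint
    load_in_8bit=True,
    device_map="auto",
)

# Replaces the manual get_input_embeddings().register_forward_hook(...) call:
# transformers installs an equivalent hook on the input embeddings for us.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()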

younesbelkada avatar younesbelkada commented on May 19, 2024

Very cool! Don't hesitate to share your adapters on the Hub once trained! 🤗
