Comments (9)
It works like a charm, thank you so much!!!
You are the absolute best for solving this so quickly and providing an example right away, thank you so much. I just tested it with XXL and it's training very well!
Can you try to replace:

model.decoder.project_in = lambda x: x.requires_grad_(True)

with:

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.shared.register_forward_hook(make_inputs_require_grad)
This will be properly addressed to make it more user-friendly in: huggingface/transformers#21598
Hello @iliasmiraoui, thanks to @younesbelkada we were able to reproduce and debug the issue. For T5 models that were trained in BF16, the following changes are required because load_in_8bit keeps certain layers in FP32:

- Call model.enable_input_require_grads(), as mentioned above, so that gradient checkpointing works properly.
- Cast the input of lm_head to FP16 before casting the logits back to FP32; diff snippet:

class CastOutputToFloat(nn.Sequential):
-    def forward(self, x): return super().forward(x).to(torch.float32)
+    def forward(self, x): return super().forward(x.to(torch.float16)).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)

- Disable autocast in the train loop. When using the Trainer, don't pass any mixed-precision argument such as fp16/bf16 in the training arguments.
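Putting the three changes together, a minimal sketch (assuming a recent transformers version that exposes enable_input_require_grads; google/flan-t5-xxl is just an example checkpoint, see the Colab below for the full recipe):

import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, TrainingArguments

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-xxl",  # any BF16-trained T5 checkpoint
    load_in_8bit=True,
    device_map='auto',
)

# 1. make the frozen embedding outputs require grad so gradient checkpointing works
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# 2. cast the lm_head input to fp16 and the logits back to fp32
class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x.to(torch.float16)).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)

# 3. no autocast / mixed precision: leave fp16 and bf16 unset
training_args = TrainingArguments(output_dir='outputs', per_device_train_batch_size=4)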
Here is a working Colab example: https://colab.research.google.com/drive/1M1hFb5Rr_MSDKByHRqMP_9AZQ5TKj--H?usp=sharing
Let us know if that resolves the issue.
I can confirm this worked for flan-t5-base:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base",
    load_in_8bit=True,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later
    if param.ndim == 1:
        # cast the small parameters (e.g. layernorm) to fp32 for stability
        param.data = param.data.to(torch.float32)

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)

import transformers
from datasets import load_dataset

data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)

trainer = transformers.Trainer(
    model=model,
    train_dataset=data['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir='outputs'
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()
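As a follow-up to the use_cache note above, a rough inference sketch once training has finished (the prompt is just an illustrative placeholder):

model.config.use_cache = True  # re-enable the cache for generation
model.eval()

batch = tokenizer("Two things are infinite: ", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_tokens = model.generate(**batch, max_new_tokens=50)
print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))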
@younesbelkada Surprisingly, it doesn't seem like the model is learning or that the loss is converging (compared to the same loop in float32).
Hello @iliasmiraoui, could you provide a minimal reproducible script that we can run?
@iliasmiraoui now you should use model.enable_input_require_grads(); I have updated the Colab notebook accordingly.
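Concretely, in the script above that means the manual hook is no longer needed; a minimal sketch of the replacement (same effect, one API call):

# before: register a forward hook on the input embeddings by hand
# model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

# after: let transformers handle it
model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()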
Very cool! Don't hesitate to share your adapters on the Hub once trained! 🤗
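For reference, saving and sharing the trained adapter can be as simple as the sketch below (the repo id is a hypothetical placeholder):

model.save_pretrained("flan-t5-base-lora")            # writes only the small adapter weights
model.push_to_hub("your-username/flan-t5-base-lora")  # hypothetical Hub repo id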