Comments (4)

BenjaminBossan commented on June 2, 2024

Thanks for reporting. I dug a bit deeper. The offending line, at least in my setup, is:

abs_diff = torch.abs(weight_divabs - L_reshaped) # (L, B, 2**K)

With the incoming weight having a shape of (3072, 3072), we have:

  • weight_divabs => (147456, 64, 1)
  • L_reshaped => (1, 256)
  • abs_diff => (147456, 64, 256)

So abs_diff tries to allocate 9 GB of memory (with float32). I wonder if we can avoid such a huge shape. Pinging @yxli2123.
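The 9 GB figure can be checked with a little arithmetic, without allocating anything. The following is a minimal pure-Python sketch; the block size of 64 and the 256 lookup values are read off the shapes quoted above, and `broadcast_shape` is a hypothetical helper written for this illustration, not part of PEFT or PyTorch.

```python
from math import prod


def broadcast_shape(a, b):
    """Return the shape that shapes a and b broadcast to (NumPy/PyTorch rules)."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} do not broadcast")
        out.append(y if x == 1 else x)
    return tuple(out)


weight_shape = (3072, 3072)  # incoming weight, from the thread
block_size = 64              # assumed from the middle dimension of weight_divabs
num_levels = 256             # assumed from the last dimension of L_reshaped
bytes_per_float32 = 4

num_blocks = prod(weight_shape) // block_size
weight_divabs_shape = (num_blocks, block_size, 1)
l_reshaped_shape = (1, num_levels)

# The subtraction broadcasts both operands to this shape before torch.abs runs.
abs_diff_shape = broadcast_shape(weight_divabs_shape, l_reshaped_shape)
abs_diff_bytes = prod(abs_diff_shape) * bytes_per_float32

print(abs_diff_shape)
print(abs_diff_bytes / 2**30, "GiB")
```

Since the intermediate grows linearly with the number of blocks, handling the blocks in smaller batches would shrink the peak proportionally. Whether that fits the LoftQ implementation is a question for the maintainers.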

What you could try right now is the replace_lora_weights_loftq function. It lets you load the model with bnb-quantized weights, i.e. with a lower memory requirement, and apply LoftQ on the fly with relatively little overhead. I tried this on my machine and memory usage stayed below 5 GB throughout:

import gc

import torch
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    replace_lora_weights_loftq,
)
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"

# Load the base model already quantized to 4 bit with bitsandbytes,
# so the full-precision weights never need to fit in memory.
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# Plain LoRA config; LoftQ is applied afterwards by replace_lora_weights_loftq.
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    modules_to_save=None,
    use_rslora=True,
)

model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
model = prepare_model_for_kbit_training(model)
# model = model.to("cpu")

# Free cached memory before adding the adapters.
torch.cuda.empty_cache()
gc.collect()

model = get_peft_model(model, peft_config)
replace_lora_weights_loftq(model)  # takes a couple of minutes

Note that using this approach is more memory efficient, but it might not perform as well, at least not without making use of the callback feature described in this LoftQ init notebook.
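As a rough illustration of that callback feature: replace_lora_weights_loftq accepts an optional `callback` argument that is called with the model and the name of the module that was just replaced, and returns a bool saying whether the replacement should be kept. The sketch below builds such a callback from an arbitrary error metric. `make_keep_if_better_callback` and `error_fn` are hypothetical names for this example; in practice `error_fn` might compare the quantized model's logits against those of the unquantized model on a few sample inputs.

```python
def make_keep_if_better_callback(error_fn):
    """Build a callback that keeps a LoftQ replacement only if the error improves.

    error_fn is a hypothetical user-supplied function mapping the model to a
    float (lower is better), e.g. the MSE between quantized and reference logits.
    """
    best_error = float("inf")

    def callback(model, module_name):
        nonlocal best_error
        error = error_fn(model)
        if error < best_error:
            best_error = error
            return True   # keep the new LoRA weights for this module
        return False      # signal that the replacement should be rolled back

    return callback


# Assumed usage with the model from the snippet above:
# callback = make_keep_if_better_callback(my_error_fn)
# replace_lora_weights_loftq(model, callback=callback)
```

The callback is greedy: each module's replacement is judged only against the best error seen so far, so the order in which modules are visited can affect which replacements survive.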

adamamer20 commented on June 2, 2024

@BenjaminBossan Thank you for the advice, your method works! So the issue was that the (3072, 3072) weight matrix was being quantized all at once and there wasn't enough memory available for the intermediate computations.

Can you clarify, however, what replace_lora_weights_loftq does? From the source code, it seems to assume that the LoRA adapter weights are already quantized, but there's no mention of quantization in your LoraConfig. Are the LoRA weights initialized as quantized because the model weights are quantized?

BenjaminBossan commented on June 2, 2024

> Can you clarify, however, what replace_lora_weights_loftq does? From the source code, it seems to assume that the LoRA adapter weights are already quantized, but there's no mention of quantization in your LoraConfig. Are the LoRA weights initialized as quantized because the model weights are quantized?

The LoRA weights are never quantized, regardless of whether the base model is quantized or not. This is necessary because quantized weights cannot be trained, and we want the LoRA weights to be trained. But since the total number of parameters of the LoRA weights is typically small, this should still result in less memory being used than full fine-tuning.
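To put "typically small" into numbers, here is a back-of-the-envelope count for a single (3072, 3072) linear layer like the one discussed above, using the rank r = 8 from the earlier config. It assumes the standard LoRA parameterization with two low-rank factors of shapes (r, in_features) and (out_features, r); the variable names are chosen for this example only.

```python
in_features, out_features = 3072, 3072
r = 8  # LoRA rank, as in the LoraConfig above

base_params = in_features * out_features          # 9,437,184 frozen (quantized) weights
lora_params = r * in_features + out_features * r  # 49,152 trainable LoRA weights

ratio = lora_params / base_params
print(f"{lora_params} LoRA params vs {base_params} base params ({ratio:.2%})")
```

For this layer the trainable adapter is roughly half a percent of the base weight, which is why keeping it unquantized costs little memory.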

adamamer20 commented on June 2, 2024

You're right, I forgot that quantized weights themselves can't be trained and stay frozen. Thank you very much.
