Comments (11)

BenjaminBossan avatar BenjaminBossan commented on June 10, 2024

That's very strange; the steps you take look correct. Could you please paste the full error message? Are your adapters publicly accessible somewhere so that I can try to reproduce the problem?

from peft.

YnezT0311 avatar YnezT0311 commented on June 10, 2024

Sure, here is the full error message, just several lines:

Exception has occurred: ValueError
Adapter id_2 not found in odict_keys(['id_1'])
  File "xxxx/code/merge_weight_peft.py", line 169, in main
    merged_model.set_adapter('id_2')
  File "xxxx/code/merge_weight_peft.py", line 207, in <module>
    app.run(main)
ValueError: Adapter id_2 not found in odict_keys(['id_1'])

I have uploaded the adapter_model.bin files of the two models I want to merge to YnezT/backdoor-sst2-bert-base-uncased and YnezT/clean-sst2-bert-base-uncased, respectively. What I'm trying to do is merge these two into a merged_model, then set that as the main adapter:

base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

merged_model = PeftModel.from_pretrained(base_model, backdoor_model_adapter, adapter_name="backdoor_model")
merged_model.load_adapter(clean_model_adapter, adapter_name="clean_model")
if merge_type == "ties-merge":
    # ties-merge
    adapters = ["clean_model", "backdoor_model"]
    weights = [1.0, 1.0]
    adapter_name = "merge_model"
    density = 0.2
    combination_type = "ties"
    merged_model.add_weighted_adapter(adapters, weights, adapter_name, combination_type=combination_type, density=density)
    merged_model.set_adapter(adapter_name)

Thanks a lot!

BenjaminBossan avatar BenjaminBossan commented on June 10, 2024

Thanks for the additional context. However, I could not reproduce the error. Below is a code snippet. Note that I used my own LoRA adapters because yours are .bin files rather than safetensors (I avoid loading those as a safety measure), but that shouldn't really change things.

from transformers import AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, PeftModel

# creating adapters
config0 = LoraConfig(init_lora_weights=False)
config1 = LoraConfig(init_lora_weights=False)
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = get_peft_model(base_model, config0)
peft_model.add_adapter("other", config1)
peft_model.save_pretrained("/tmp/peft/bert")

# same as your code
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
merged_model = PeftModel.from_pretrained(base_model, "/tmp/peft/bert/", adapter_name="backdoor_model")
merged_model.load_adapter("/tmp/peft/bert/other", adapter_name="clean_model")
adapters = ["clean_model", "backdoor_model"]
weights = [1.0, 1.0]
adapter_name = "merge_model"
density = 0.2
combination_type = "ties"
merged_model.add_weighted_adapter(adapters, weights, adapter_name, combination_type=combination_type, density=density)
merged_model.set_adapter(adapter_name)  # works
print(merged_model.active_adapters)  # shows ['merge_model']

What's also strange is that your error message refers to "id_1" and "id_2" but your code uses different adapter names. Are you sure that there isn't something else going on?

YnezT0311 avatar YnezT0311 commented on June 10, 2024

Sorry, I manually changed the names in the error message for consistency with the initial query, but forgot to change them in the later code. Here id_1 stands for backdoor_model and id_2 for merge_model.

Could I ask for your log and Pytorch version? When running my code, I saw this:

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Not sure whether it is related.

BenjaminBossan avatar BenjaminBossan commented on June 10, 2024

I have torch v2.2.0, transformers v4.39.0 and latest PEFT.

My log is a bit different to yours:

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

This is expected, since we use AutoModelForSequenceClassification, so a classification layer is added, which is untrained. When you do your fine-tuning, it should be added to modules_to_save in your LoraConfig. Regardless of that, I don't see how the error you described could possibly occur.
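As a minimal sketch of the modules_to_save suggestion above: the module name "classifier" is taken from the warning in the log ('classifier.bias', 'classifier.weight'); every other LoRA setting is left at its default, which is an assumption and not something this thread prescribes.

```python
from peft import LoraConfig

# "classifier" is the untrained head named in the warning above.
# Listing it in modules_to_save asks PEFT to train and save a full copy
# of that module alongside the LoRA weights.
config = LoraConfig(
    init_lora_weights=False,          # as in the snippets in this thread
    modules_to_save=["classifier"],
)
```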

YnezT0311 avatar YnezT0311 commented on June 10, 2024

Apologies for my late reply; I was preoccupied with something else last week.

I noticed a difference between your code and mine.

config0 = LoraConfig(init_lora_weights=False)
config1 = LoraConfig(init_lora_weights=False)
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = get_peft_model(base_model, config0)
peft_model.add_adapter("other", config1)
peft_model.save_pretrained("/tmp/peft/bert")

In your code, these two adapters are attached to the "same" model, i.e., .save_pretrained will save both adapters together. Is this required for merging?

What I did is somewhat similar to:

config0 = LoraConfig(init_lora_weights=False)
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = get_peft_model(base_model, config0)
peft_model.save_pretrained("/tmp/peft/bert/main")

and

config1 = LoraConfig(init_lora_weights=False)
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = get_peft_model(base_model, config1)
peft_model.save_pretrained("/tmp/peft/bert/other")

and then

base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
merged_model = PeftModel.from_pretrained(base_model, "/tmp/peft/bert/main", adapter_name="backdoor_model")
merged_model.load_adapter("/tmp/peft/bert/other", adapter_name="clean_model")

I guess this is where the problem I encountered arises?

YnezT0311 avatar YnezT0311 commented on June 10, 2024

I have tested it, and it appears not to be the reason. Both work. I might figure out the problem myself. Thanks!

BenjaminBossan avatar BenjaminBossan commented on June 10, 2024

> In your code, here, these two adapters are used for training the "same" model, i.e., the .save_pretrained will save these two adapters together. Is this required for merging?

> I have tested it, and it appears not to be the reason.

I just want to confirm that this is not necessary and should not be the reason for the problem you encounter.

> I might figure out the problem myself. Thanks!

Feel free to share new insights or questions that you may have in this issue.

YnezT0311 avatar YnezT0311 commented on June 10, 2024

Thank you for your patience. I have identified where the problem lies, but I do not know how to solve it: if we specify the task_type in LoraConfig, an error occurs. For example, the following code raises an error.

from transformers import AutoModelForSequenceClassification
from peft import get_peft_model, LoraConfig, PeftModel, TaskType

# creating adapters
config0 = LoraConfig(task_type=TaskType.SEQ_CLS, init_lora_weights=False)
config1 = LoraConfig(task_type=TaskType.SEQ_CLS, init_lora_weights=False)
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = get_peft_model(base_model, config0)
peft_model.add_adapter("other", config1)
peft_model.save_pretrained("tmp/peft/bert/")

# same as your code
base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
merged_model = PeftModel.from_pretrained(base_model, "tmp/peft/bert/", adapter_name="backdoor_model")
merged_model.load_adapter("tmp/peft/bert/other", adapter_name="clean_model")
adapters = ["clean_model", "backdoor_model"]
weights = [1.0, 1.0]
adapter_name = "merge_model"
density = 0.2
combination_type = "ties"
merged_model.add_weighted_adapter(adapters, weights, adapter_name, combination_type=combination_type, density=density)
merged_model.set_adapter(adapter_name) # ValueError: Adapter merge_model not found in odict_keys(['backdoor_model'])
print(merged_model.active_adapters)

However, if we remove the task_type, training fails with an evaluation metric mismatch (the KeyError shown below). My current workaround is to disable the evaluation phase in Trainer and evaluate at the end myself, but I assume there is a more standard solution.

Exception has occurred: KeyError
'eval_loss'
  File "xxxx", line 127, in <module>
    clean_trainer.train()
KeyError: 'eval_loss'
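One plausible explanation for the missing 'eval_loss', offered as an assumption and not confirmed anywhere in this thread: Trainer may infer the label argument names from the model's forward signature, and a wrapper whose forward only takes *args and **kwargs hides the labels parameter. The sketch below uses stand-in classes (ClassifierLike, GenericWrapper and label_like_parameters are hypothetical names, not PEFT or transformers APIs) to show the effect with the standard library only.

```python
import inspect


class ClassifierLike:
    """Stand-in for a model whose forward() names its label argument."""

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return {"loss": 0.0} if labels is not None else {}


class GenericWrapper:
    """Stand-in for a wrapper that passes everything through *args/**kwargs."""

    def __init__(self, model):
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)


def label_like_parameters(model_class):
    """Return the forward() parameter names that contain 'label'."""
    signature = inspect.signature(model_class.forward)
    return [name for name in signature.parameters if "label" in name]


print(label_like_parameters(ClassifierLike))   # ['labels']
print(label_like_parameters(GenericWrapper))   # []
```

If this is indeed the cause, explicitly passing label_names=["labels"] to TrainingArguments might restore the evaluation loss without disabling evaluation; that remedy is untested against the setup described here.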

BenjaminBossan avatar BenjaminBossan commented on June 10, 2024

Thanks for digging further. I can reproduce the error and this is indeed caused by a bug in PEFT, which actually runs deeper than what is reported in this issue. I'll work on a bugfix and will get back to you once it's ready.

BenjaminBossan avatar BenjaminBossan commented on June 10, 2024

I'm working on a solution for this. Some parts are addressed in the PR linked above but there is also another issue which is not easily fixed. The problem with your example is that if you want to use sequence classification, it means that we also need to train the classifier head. This isn't done with LoRA, instead the classifier head is trained using full fine-tuning (based on a copy of the weights of the original classifier head). When you have multiple adapters, each will have its own classifier head (once the PR is merged).
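To see which adapters currently own a fully fine-tuned copy of a module such as the classifier head, a small diagnostic along these lines may help. It is a sketch that assumes wrapper modules expose a modules_to_save mapping keyed by adapter name; check that attribute name against your installed PEFT version. The helper name adapters_with_saved_copies and the Fake* classes are hypothetical stand-ins so the snippet runs without loading a model.

```python
def adapters_with_saved_copies(model):
    """Map module name -> adapter names that hold a full copy of that module."""
    report = {}
    for name, module in model.named_modules():
        copies = getattr(module, "modules_to_save", None)
        # Only mapping-like containers are of interest; skip sets, lists or None.
        if copies is not None and hasattr(copies, "keys"):
            report[name] = list(copies.keys())
    return report


class FakeWrapper:
    """Stand-in for a wrapped module holding per-adapter copies."""

    def __init__(self, adapter_names):
        self.modules_to_save = {name: object() for name in adapter_names}


class FakeModel:
    """Stand-in exposing the named_modules() interface of a torch module."""

    def named_modules(self):
        yield "", self
        yield "bert.encoder", object()
        yield "classifier", FakeWrapper(["backdoor_model"])


report = adapters_with_saved_copies(FakeModel())
print(report)  # {'classifier': ['backdoor_model']}
```

With a real model, calling adapters_with_saved_copies(merged_model) before set_adapter would show whether the adapter you are about to activate is missing from any wrapped module, which is the situation the ValueError in this thread describes.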

Now the problem is that you want to merge the two adapters. For the LoRA weights, there are a few ways that they can be merged. But these classifier heads cannot be merged, as they each have a fine-tuned copy of the full weights. Therefore, even with the bugfix, your original problem wouldn't be solved. In fact, we may even start to explicitly check this during merging and raise an error.
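For intuition about what combination_type="ties" with a density does to the LoRA weights, here is a toy sketch of TIES-style merging on plain Python lists: trim each update to its largest-magnitude entries, elect a sign per position, then average only the entries that agree with that sign. It is an illustration under simplifying assumptions, not PEFT's implementation, which operates on tensors and may differ in how the cutoff, weights and sign election are applied. The function names trim and ties_merge are hypothetical.

```python
def trim(vec, density):
    """Keep the `density` fraction of entries with largest magnitude; zero the rest."""
    k = max(1, int(round(density * len(vec))))
    order = sorted(range(len(vec)), key=lambda i: abs(vec[i]), reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(vec)]


def ties_merge(task_vectors, weights, density):
    """Toy TIES merge: trim, elect a sign per position, average agreeing entries."""
    trimmed = [
        [w * v for v in trim(vec, density)]
        for vec, w in zip(task_vectors, weights)
    ]
    merged = []
    for column in zip(*trimmed):
        # Elect the sign from the summed (weighted) values at this position.
        positive = sum(column) >= 0
        agreeing = [v for v in column if v != 0.0 and (v > 0) == positive]
        merged.append(sum(agreeing) / len(agreeing) if agreeing else 0.0)
    return merged


delta_a = [0.9, -0.1, 0.0, 0.5]
delta_b = [-0.6, 0.8, 0.05, 0.4]
print(ties_merge([delta_a, delta_b], [1.0, 1.0], density=0.5))
# [0.9, 0.8, 0.0, 0.5]
```

In the first position the two updates disagree in sign, so only the entry matching the elected sign survives.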
