saeid93 opened 2 weeks ago
Thanks a lot for reporting this. Indeed, the handling of modules_to_save can be messy at times, and the outcome you show should be avoided. I don't have the opportunity to test this right now, but my assumption is that this extra module won't disrupt the results for adapter 2, because it is a copy of the original layer and behaves exactly the same. Is that right?
No worries, glad to be of any help. As far as I have tested, it should be fine and uses the correctly loaded layer; the only problem is the redundancy in the loaded modules. I also dug a bit deeper and noticed that the problem originates from this function:
https://github.com/huggingface/peft/blob/162d7e57ee0088f42eb0f26150bd9170d30f3637/src/peft/peft_model.py#L966
For an unknown reason, when using load_adapter:
https://github.com/huggingface/peft/blob/162d7e57ee0088f42eb0f26150bd9170d30f3637/src/peft/peft_model.py#L969
the set is not updated to contain only the new layer; it still holds the old layer as well (which it shouldn't). For example, if I manually hack the above script, the problem is solved:
...
# Apply and save the second adapter
os.makedirs(adapter_2_path, exist_ok=True)
model_with_lora_2 = get_peft_model(base_model, lora_config_2, adapter_name="adapter_2")
model_with_lora_2.save_pretrained(adapter_2_path)
# Load a fresh base model and wrap it in PeftModel by loading the first adapter
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = PeftModel.from_pretrained(base_model, os.path.join(adapter_1_path, "adapter_1"), adapter_name="adapter_1")
peft_model.modules_to_save = {"wte"} # <----------- HERE manually changing the modules_to_save
# Load the second adapter into the PeftModel
peft_model.load_adapter(os.path.join(adapter_2_path, "adapter_2"), adapter_name="adapter_2")
...
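As a quick way to double-check which adapter copies each wrapped module ends up holding, one can inspect PEFT's ModulesToSaveWrapper instances; a rough sketch, reusing the peft_model from the snippet above:
from peft.utils import ModulesToSaveWrapper

# Print every module wrapped for modules_to_save and the adapter names it
# holds a trainable copy for; with the bug described here, adapter_1's
# module also lists a copy under adapter_2.
for name, module in peft_model.named_modules():
    if isinstance(module, ModulesToSaveWrapper):
        print(name, "->", list(module.modules_to_save.keys()))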
Okay, I managed to reproduce the error. My tentative fix is in #2220. Right now, there is a CI issue but it should hopefully resolve itself soon. Meanwhile, it would be great if you could check if the fix makes sense to you.
Just a side note: when adding multiple adapters, don't call get_peft_model twice; use peft_model.add_adapter instead (see the sketch below). But even with that change, the problem is reproducible.
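For reference, a minimal sketch of that pattern; the configs, target modules, and adapter names below are illustrative (loosely based on the gpt2 setup from the reproduction), not taken from the original script:
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")

# Illustrative configs: each adapter targets the attention projection and keeps
# a full trainable copy of a different extra module via modules_to_save.
lora_config_1 = LoraConfig(target_modules=["c_attn"], modules_to_save=["lm_head"])
lora_config_2 = LoraConfig(target_modules=["c_attn"], modules_to_save=["wte"])

# Wrap the base model once with the first adapter ...
peft_model = get_peft_model(base_model, lora_config_1, adapter_name="adapter_1")
# ... then register further adapters on the same PeftModel instead of calling
# get_peft_model again on the already wrapped model.
peft_model.add_adapter("adapter_2", lora_config_2)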
System Info
Python 3.11.9 transformers==4.40.2 peft==0.11.2
Who can help?
@BenjaminBossan

A bug occurs in the PEFT library when using multiple LoRA adapters, each with a unique modules_to_save configuration. The issue arises when the modules_to_save from the first LoRA adapter (e.g., adapter_1) is applied to subsequent adapters (e.g., adapter_2), rather than each adapter maintaining an independent configuration. As a result, modules specified in modules_to_save for adapter_1 also appear in adapter_2, leading to unintended behavior and possibly affecting fine-tuning accuracy. This incorrect handling of modules_to_save causes duplicate entries where only the respective LoRA adapter's modules should be saved.
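For illustration, a minimal sketch of the kind of configuration described; the target modules and module names are assumptions for gpt2 and not the exact values from the example code:
from peft import LoraConfig

# Each adapter is meant to save a different extra module alongside its LoRA weights.
lora_config_1 = LoraConfig(target_modules=["c_attn"], modules_to_save=["lm_head"])  # adapter_1
lora_config_2 = LoraConfig(target_modules=["c_attn"], modules_to_save=["wte"])      # adapter_2
# Observed behavior: after loading both adapters, adapter_2 also receives a copy
# of adapter_1's module instead of only its own modules_to_save entry.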
Reproduction
The following example code demonstrates the issue, displaying the model structure where adapter_2 contains modules meant only for adapter_1.
Example Code
The code output will be:
Expected behavior
As you can see, adapter 2 is also built for the "lm_head" module, which it shouldn't be; the expected output is shown below: