huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

With tied embeddings, adapter merged into tied layers #2018

Closed ltoniazzi closed 1 month ago

ltoniazzi commented 2 months ago

System Info

peft==0.12.0, transformers==4.44.0

Who can help?

No response

Information

Tasks

Reproduction

With Gemma2, a model where tie_word_embeddings = True, using target_modules=["lm_head"] and merging the adapter also merges the adapter into the tied embedding layer (embed_tokens), which is incorrect.

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it")
# model.config.tie_word_embeddings = False # doing this does not change the outcome

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["lm_head"],
    bias="none",
    task_type="CAUSAL_LM",
    init_lora_weights=False,
)
clone_lm_head_orig = model.lm_head.weight.data.clone()
model = get_peft_model(model, config)

# Check embed_tokens and lm_head point to the same data
assert model.model.model.embed_tokens.weight.data.data_ptr() == model.model.lm_head.weight.data.data_ptr()

# Merge adapter in the base model
model.merge_and_unload(safe_merge=True, adapter_names=["default"])

# Check adapter is merged
assert not torch.equal(clone_lm_head_orig, model.model.lm_head.weight.data)

# Check the embedding layer is unchanged by merging the lm_head adapter.
# This is the check that exposes the issue: with tied weights, merging into lm_head also modifies embed_tokens.
assert model.model.model.embed_tokens.weight.data.data_ptr() != model.model.lm_head.weight.data.data_ptr(), "Embedding layer should not have changed"

Expected behavior

I think that merging should not succeed silently; it should instead either raise an error or emit a warning when the targeted module shares its weights with a tied layer.

Related issues

BenjaminBossan commented 2 months ago

Thanks for opening this issue. Yes, I agree that this is an easy source of errors, and having a warning would help.

The main reason why this is not implemented yet is that merging is a layer-level operation in PEFT. An individual layer, however, cannot know whether its weights are tied or not, so we cannot easily check for this at merge time. It might be possible to refactor this to work differently, but I don't see an easy way.
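A minimal sketch in plain PyTorch (not PEFT internals) of why the layer itself cannot see the tie: with tied embeddings, lm_head and embed_tokens hold the very same Parameter object, so an update applied through one module silently shows up in the other, and nothing local to lm_head reveals the sharing.

import torch
from torch import nn

embed_tokens = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed_tokens.weight  # weight tying: both modules now hold the same Parameter

delta = torch.ones_like(lm_head.weight)  # stand-in for a merged LoRA delta
lm_head.weight.data += delta             # "merge" into lm_head only

# The embedding changed as well, and lm_head has no local attribute that
# would tell a merge routine that its weight is shared.
assert torch.equal(embed_tokens.weight, lm_head.weight)
assert embed_tokens.weight.data_ptr() == lm_head.weight.data_ptr()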

We could still try to make an educated guess based on model.config.tie_word_embeddings and the actual target_modules and that should help most users who face this situation. If you are interested in working on this, feel free to create a PR. Otherwise, I'll put this on the backlog and work on this when I have a bit of time on my hands.
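For illustration, a hedged sketch of such an educated guess; warn_if_tied_target is a hypothetical helper, not an existing PEFT function, and the attribute checks are assumptions about typical transformers configs.

import warnings

def warn_if_tied_target(model, peft_config):
    config = getattr(model, "config", None)
    tied = getattr(config, "tie_word_embeddings", False)
    targets = peft_config.target_modules
    if isinstance(targets, str):  # target_modules may also be a regex string
        targets = {targets}
    suspicious = {"lm_head", "embed_tokens"} & set(targets or [])
    if tied and suspicious:
        warnings.warn(
            f"tie_word_embeddings is True and {sorted(suspicious)} are in "
            "target_modules; merging this adapter will also modify the tied weight."
        )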

Make the B matrix non-zero

This can also be achieved by passing init_lora_weights=False to the LoraConfig :)
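For reference, a small sketch (toy module, not from this thread) showing that init_lora_weights=False leaves lora_B randomly initialized instead of zeroed, so the freshly added adapter already produces a non-zero delta:

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

base = nn.Sequential(nn.Linear(16, 16))
config = LoraConfig(r=8, lora_alpha=16, target_modules=["0"], init_lora_weights=False)
peft_model = get_peft_model(base, config)

# With the default init, lora_B would be all zeros and the adapter a no-op.
lora_B = peft_model.base_model.model[0].lora_B["default"].weight
assert lora_B.abs().sum() > 0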

ltoniazzi commented 2 months ago

If you are interested in working on this, feel free to create a PR.

Yes sure, happy to have a go at it later this week!

BenjaminBossan commented 2 months ago

Fantastic, thanks. Don't hesitate to ask me if something is unclear, or to create a draft PR for early feedback.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

BenjaminBossan commented 1 month ago

Resolved via #2025.