unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
18.23k stars, 1.27k forks

Gemma 2B LoRA merging is broken #219

Closed: oKatanaaa closed this issue 1 month ago

oKatanaaa commented 8 months ago

I've trained Gemma 2B in 16-bit with LoRA. With the adapters loaded separately, everything works just fine. But after merging the adapters, the model becomes literally unusable.

[Screenshot]

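Here is the code used to load the models, merge the adapters, and reload the merged model:

```python
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

# Load the base model together with the trained LoRA adapters
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "trained", # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = 8192,
    dtype = None,              # None = auto-detect
    load_in_4bit = False,
    resize_model_vocab=256001  # local workaround, not a stock unsloth argument (see below)
)
tokenizer.add_special_tokens({'additional_special_tokens': ['<|im_start|>']})
tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",
    map_eos_token=True
)
FastLanguageModel.for_inference(model)

# Merge the adapters into the base weights and save in 16-bit
model.save_pretrained_merged("merged", tokenizer, save_method = "merged_16bit")

# Reload the merged checkpoint for comparison
model2, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "merged", # merged 16-bit model saved above
    max_seq_length = 8192,
    dtype = None,
    load_in_4bit = False,
    #resize_model_vocab=32001
)
FastLanguageModel.for_inference(model2)
```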

The model was trained with the ChatML format, which is why the code adds a token. The resize_model_vocab parameter is a workaround I added to load a vocabulary of a different size.
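
ChatML wraps every turn in <|im_start|> and <|im_end|> markers. The minimal formatter below only illustrates that layout. The function name to_chatml is hypothetical, and this is not how unsloth's get_chat_template is implemented:

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render a list of {"role", "content"} dicts in the ChatML layout."""
    parts = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
        for m in messages
    ]
    if add_generation_prompt:
        # Open an assistant turn so generation continues from here.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


prompt = to_chatml(
    [{"role": "system", "content": "You are helpful."},
     {"role": "user", "content": "Hi!"}],
    add_generation_prompt=True,
)
print(prompt)
```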

Also, the saved adapters weigh 6 GB. Is that alright? Note that the merged model is 5.7 GB. I believe adapters should be a hundred MB tops (maybe a GB with the saved vocab and lm_head), but presumably the whole model got saved. [Screenshot]

Important note: during training, I passed ["embed_tokens", "lm_head"] in the modules_to_save parameter to train the new ChatML tokens. I am not sure how that interacts with the fact that Gemma's embed_tokens weights are tied to lm_head (I believe?). Maybe that's actually why merging fails? For example, maybe you have to pass only embed_tokens, otherwise everything breaks (just a hypothesis).

Dependencies:

accelerate                0.27.2
datasets                  2.16.1
huggingface-hub           0.20.3
ipykernel                 6.29.0
ipython                   8.15.0
jedi                      0.18.1
Jinja2                    3.1.2
numpy                     1.26.2
peft                      0.8.2
safetensors               0.4.2
scikit-learn              1.4.0
scipy                     1.12.0
sentence-transformers     2.3.1
sentencepiece             0.1.99
sympy                     1.12
tensorboardX              2.6.2.2
torch                     2.1.2
torchaudio                2.1.2
torchelastic              0.2.2
torchvision               0.16.2
transformers              4.38.1
triton                    2.1.0
trl                       0.7.11
unsloth                   2024.3
xformers                  0.0.23.post1

danielhanchen commented 8 months ago

OHHH I forgot to say that get_chat_template is broken for Gemma :( What you're looking for is "gemma_chatml" instead of "chatml". It will automatically add <|im_start|> and <|im_end|>. You also don't need to add any special tokens, since I handle that internally!

oKatanaaa commented 8 months ago

My findings so far:

  1. In the PEFT model, lm_head and embed_tokens have different data pointers (is that okay?), but the tensors are equal.
  2. In the merged model, the data pointers are the same. Both tensors are equal as well.
  3. lm_head and embed_tokens in the merged model are equal to lm_head and embed_tokens in the PEFT model.

[Screenshot]

So the problem is not the embedding weights but something else.
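
In PyTorch, the two checks described above are usually a.data_ptr() == b.data_ptr() (same storage) and torch.equal(a, b) (same values). The stdlib-only sketch below mimics them with array.array, whose buffer_info()[0] is the buffer address. It shows how tied weights differ from a copy that merely has equal values. The helper name compare_weights and the toy values are made up for illustration:

```python
from array import array


def compare_weights(a, b):
    """Return (shares_storage, equal_values) for two weight buffers."""
    shares_storage = a.buffer_info()[0] == b.buffer_info()[0]  # cf. Tensor.data_ptr()
    equal_values = a == b                                       # cf. torch.equal
    return shares_storage, equal_values


embed_tokens = array("f", [0.1, -0.2, 0.3])
lm_head_tied = embed_tokens                # tied: the very same buffer
lm_head_copy = array("f", embed_tokens)    # untied: a separate buffer, same values

print(compare_weights(embed_tokens, lm_head_tied))  # (True, True)
print(compare_weights(embed_tokens, lm_head_copy))  # (False, True)
```

The first case matches what was seen in the merged model, and the second matches the PEFT model. In both cases the values agree, which is why the embeddings were ruled out.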

oKatanaaa commented 8 months ago

Regarding get_chat_template: in my setting it works fine, at least for the PEFT model. Both models were loaded in the same notebook, so the tokenizer can't be the problem (model works great, but model2 is simply dead). It must be something in the merging.

oKatanaaa commented 8 months ago

I looked at the contents of the adapter_model.safetensors file and found that it contains three copies of the embedding weights, which explains the size. That leaves me somewhat confused, though, since all three tensors are equal to each other. Why save all three and not just the one from modules_to_save?

[Screenshots]

At first I thought it was a quirk of Gemma's LoRA saving, but a Mistral LoRA adapter file (it weighs 2.1 GB) shows the same thing.
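
Here is a back-of-the-envelope check of the file size. It assumes each of the three tensors is a full vocab × hidden matrix for Gemma 2B (hidden size 2,048; vocabulary of 256,000 plus the one added token). It also assumes, without confirmation from the thread, that the tensors were stored in float32:

```python
def embedding_bytes(vocab_size, hidden_size, bytes_per_param):
    """Size in bytes of one (vocab_size x hidden_size) embedding or lm_head matrix."""
    return vocab_size * hidden_size * bytes_per_param


VOCAB = 256_001   # Gemma's 256,000 tokens + the added <|im_start|>
HIDDEN = 2_048    # Gemma 2B hidden size

for label, nbytes in (("float32", 4), ("16-bit", 2)):
    one = embedding_bytes(VOCAB, HIDDEN, nbytes)
    print(f"{label}: one copy {one / 1e9:.2f} GB, three copies {3 * one / 1e9:.2f} GB")
# float32: one copy 2.10 GB, three copies 6.29 GB
# 16-bit: one copy 1.05 GB, three copies 3.15 GB
```

Three float32 copies come to about 6.3 GB, close to the 6 GB observed. In a 16-bit dtype they would be about half that.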

oKatanaaa commented 8 months ago

Might be linked to this: https://huggingface.co/google/gemma-2b/discussions/21

Although I'm confused about why the PEFT model works so well even with the added token.

oKatanaaa commented 8 months ago

While debugging, I found that the weights of the layernorm layers in the merged model differ from those in the PEFT model.

Weights of the PEFT model: [Screenshot]

Weights of the merged model: [Screenshot]

They're just the values from the PEFT model, plus 1? Is there some trick to initializing the RMSNorm weights in Gemma?

To clarify, I looked at the value of the W variable here: [Screenshots]
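
For context, Hugging Face's Gemma implementation applies its RMSNorm weight as (1 + weight), so the checkpoint stores an offset around 0 rather than a scale around 1. A kernel written for the usual "normalized input times w" form can reproduce Gemma's output by adding 1 to the stored weights once, which would explain the +1 seen here. The pure-Python sketch below shows that equivalence. It also shows what happens if the shifted weights are read back by a Gemma-style norm. The function names are illustrative, not code from unsloth or transformers:

```python
import math


def rms_normalize(x, eps=1e-6):
    # x / sqrt(mean(x^2) + eps)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms for v in x]


def gemma_rmsnorm(x, w, eps=1e-6):
    # Gemma-style: the stored weight is an offset, applied as (1 + w)
    return [v * (1.0 + wi) for v, wi in zip(rms_normalize(x, eps), w)]


def standard_rmsnorm(x, w, eps=1e-6):
    # Llama-style: the stored weight is used directly as the scale
    return [v * wi for v, wi in zip(rms_normalize(x, eps), w)]


x = [0.5, -1.0, 2.0]
w = [0.1, -0.2, 0.05]                 # offsets as stored in a Gemma checkpoint
w_shifted = [wi + 1.0 for wi in w]    # what an in-place "+= 1" leaves in memory

matches = all(math.isclose(a, b) for a, b in
              zip(gemma_rmsnorm(x, w), standard_rmsnorm(x, w_shifted)))
print(matches)  # True

# Saving w_shifted and reading it back with a Gemma-style norm adds 1 twice:
print(gemma_rmsnorm(x, w_shifted) != gemma_rmsnorm(x, w))  # True
```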

oKatanaaa commented 8 months ago

Alright, I added this stupid ass fix (in unsloth_save_model) and now everything works fine:

[Screenshot: the patch in unsloth_save_model]

Although the outputs are not exactly the same, that's way better than before:

[Screenshot: model outputs]

Hope that helps 🫡
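
The patch itself (made in unsloth_save_model) survives only as a screenshot. The idea discussed in this thread is to undo the in-place +1 on the norm weights before they are written out, and it can be sketched as follows. Several details are assumptions: undo_norm_offset and the plain-list state dict are hypothetical, and the key suffixes assume Hugging Face's Gemma parameter names:

```python
# Assumed Hugging Face Gemma parameter-name suffixes for the RMSNorm weights.
NORM_SUFFIXES = (
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "model.norm.weight",
)


def undo_norm_offset(state_dict, suffixes=NORM_SUFFIXES):
    """Return a copy of state_dict with 1 subtracted from every RMSNorm weight.

    Hypothetical helper: assumes the norm weights were shifted by +1 in memory
    and that each tensor is represented here as a plain list of floats.
    """
    fixed = {}
    for name, values in state_dict.items():
        if name.endswith(suffixes):
            fixed[name] = [v - 1.0 for v in values]  # back to Gemma's offset form
        else:
            fixed[name] = list(values)               # everything else unchanged
    return fixed


in_memory = {
    "model.layers.0.input_layernorm.weight": [1.25, 0.5],
    "model.norm.weight": [1.5],
    "model.embed_tokens.weight": [0.1, -0.3],
}
to_save = undo_norm_offset(in_memory)
print(to_save["model.layers.0.input_layernorm.weight"])  # [0.25, -0.5]
print(to_save["model.norm.weight"])                      # [0.5]
```

Working on a copy keeps the in-memory model usable with the patched kernel after saving.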

oKatanaaa commented 8 months ago

The fix didn't quite solve it: for longer inputs, the merged model still diverges. I guess a lot of error accumulation is in play.

[Screenshot]
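
One plausible contributor, consistent with this observation but not confirmed in the thread, is that shifting 16-bit weights by +1 and back is not lossless. Near 1.0, representable values are 2^-10 apart in float16 and 2^-7 apart in bfloat16, which is much coarser than the resolution available for a small offset. The stdlib sketch below emulates both formats. The helpers to_fp16 and to_bf16 are illustrative, and bfloat16 is emulated by rounding the float32 bit pattern:

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE float16 and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Round a Python float to bfloat16 (via float32, round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def shift_roundtrip(w, cast):
    """Store w in 16 bits, apply "+= 1" and then "-= 1", rounding after each step."""
    stored = cast(w)
    shifted = cast(stored + 1.0)
    return cast(shifted - 1.0)


print(shift_roundtrip(0.001, to_fp16))  # 0.0009765625 (to_fp16(0.001) is about 0.0010004)
print(shift_roundtrip(0.001, to_bf16))  # 0.0 (the offset is lost entirely)
```

The error per weight is small, or for tiny offsets in bfloat16 the offset vanishes entirely. Because it hits every norm weight in every layer, it could plausibly add up over longer inputs.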

oKatanaaa commented 8 months ago

Sorry for all the messages.

I suspect the culprit is this line: https://github.com/unslothai/unsloth/blob/main/unsloth%2Fmodels%2Fgemma.py#L362

Although I understand the motivation, there are two problems with it:

  1. It breaks merging, which is what this issue is about: with each merge, another += 1 happens in postpatch.
  2. It may affect training in unexpected ways, since the computational graph is altered and the gradients with respect to the layernorm weights (when unfrozen) are computed differently. A constant shift should not affect the derivative, but it's safer to assume it might.
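
The constant-shift argument in point 2 can be written out. Let x̂ be the RMS-normalized input (over d features) and w' = 1 + w the shifted weight. Gemma's norm and the shifted form then give the same gradient with respect to the stored weight:

```latex
\begin{aligned}
\hat{x} &= \frac{x}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}} \\
y_i &= \hat{x}_i\,(1 + w_i) &&\Rightarrow\quad \frac{\partial y_i}{\partial w_i} = \hat{x}_i \\
w'_i &= 1 + w_i,\quad y_i = \hat{x}_i\, w'_i &&\Rightarrow\quad \frac{\partial y_i}{\partial w_i} = \frac{\partial y_i}{\partial w'_i}\,\frac{\partial w'_i}{\partial w_i} = \hat{x}_i \cdot 1 = \hat{x}_i
\end{aligned}
```

So in exact arithmetic the gradients match. Differences would have to come from elsewhere, for example rounding of the stored w' in 16 bits. Decoupled weight decay is another possible source: if the norm weights were trained, it would pull w' rather than w toward zero.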

As I see it, the best solution in terms of compatibility and memory/compute savings is a custom Triton layernorm kernel. The vanilla implementation might be good enough as well. It just needs a couple of tests, that's all.
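
The difference between the two designs can be sketched in plain Python. An in-place post-patch shifts the stored weights, so every load/save cycle without a matching -1 adds another 1. A kernel that computes (1 + w) internally never touches the stored weights. Both functions below are illustrative stand-ins, not unsloth's Triton code:

```python
import math


def rmsnorm_gemma_offset(x, w, eps=1e-6):
    """Reference forward that applies Gemma's (1 + w) inside the computation."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * (1.0 + wi) for v, wi in zip(x, w)]


def inplace_postpatch(w):
    """Hypothetical stand-in for a post-patch that does w += 1 on load."""
    for i in range(len(w)):
        w[i] += 1.0


# In-place approach: load -> patch -> save (without undoing the shift), three times.
saved = [0.1, -0.2]
for _ in range(3):
    loaded = list(saved)
    inplace_postpatch(loaded)
    saved = loaded
print(saved)  # about [3.1, 2.8]: each cycle added another 1

# Offset-in-kernel approach: the weights are only read, never modified.
weights = [0.1, -0.2]
y = rmsnorm_gemma_offset([1.0, 2.0], weights)
print(weights)  # [0.1, -0.2]
```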

danielhanchen commented 8 months ago

@oKatanaaa Oh my, thanks so so much for all the debugging, I extremely appreciate it!! I just woke up, so apologies for missing the conversation. I was going to say it's ironic that I was fixing Gemma bugs but didn't check Unsloth's own issues!! 😆

Great that you found the +1 culprit! I actually totally forgot to subtract 1 during merging. But since, according to your analysis, adding 1 and then subtracting 1 reduces accuracy, I'll just copy-paste the kernel and make it add the 1 itself. I'll do that in a few minutes and push it :)

On modules_to_save: interesting! I have never used it, since I normally finetune only the rest of the model and leave lm_head and the embedding matrix alone. I shall investigate this later today!!

Again, thanks so much for the help, I really appreciate it! I'll @ you in the fix :)

danielhanchen commented 8 months ago

Oh wait, on the layernorms: do you unfreeze them to train them?

oKatanaaa commented 8 months ago

Nope, I didn't touch those during training, but I thought it was worth pointing out the potential issues.

danielhanchen commented 8 months ago

OK, I finally fixed it! I took your advice, rewrote the kernels, and isolated it out. Hopefully GGUF saving (and merged 16-bit saving) works now.

oKatanaaa commented 8 months ago

Can confirm that merging in 16bit now works fine. No more degenerate outputs.

[Screenshot]

I guess we can attribute the difference in responses to rounding errors during the LoRA merge (I've seen it with other models as well). I'm fine with that.

Thanks for the fix, well done!
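
Some difference after merging is expected even without bugs. Merging computes W + (lora_alpha / r) · B · A and stores the result in the model's 16-bit dtype, so the adapter's contribution is rounded together with the base weight. The unmerged model instead adds it as a separate term. The toy example below uses made-up numbers and emulates bfloat16 with the stdlib. Two simplifications apply: to_bf16 is an illustrative helper, and the unmerged path is kept in full precision:

```python
import struct


def to_bf16(x):
    """Round a Python float to bfloat16 (via float32, round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def matvec(M, v):
    return [sum(m * xi for m, xi in zip(row, v)) for row in M]


W = [[0.5, -0.25], [1.0, 0.125]]  # base weight, exactly representable in bf16
A = [0.01, 0.02]                  # LoRA factor A (rank r = 1, so a single row)
B = [0.1, -0.05]                  # LoRA factor B (a single column)
scale = 1.0                       # lora_alpha / r
x = [1.0, 1.0]

# Unmerged: base output plus the adapter output, kept as a separate term.
ax = sum(a * xi for a, xi in zip(A, x))
y_unmerged = [yb + scale * b * ax for yb, b in zip(matvec(W, x), B)]

# Merged: W + scale * B A, rounded to bf16 when the merged weight is stored.
W_merged = [[to_bf16(W[i][j] + scale * B[i] * A[j]) for j in range(2)] for i in range(2)]
y_merged = matvec(W_merged, x)

print(y_unmerged)
print(y_merged)  # [0.251953125, 1.1240234375]
```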

danielhanchen commented 8 months ago

Yes, it's possible it was rounding and some other issues :)