philschmid / deep-learning-pytorch-huggingface

MIT License

LLama 2 Flash Attention Patch Not Working For 70B #32

Open mallorbc opened 1 year ago

mallorbc commented 1 year ago

The Flash Attention patch seems to be working for Llama 7B and Llama 13B (though I need to confirm more than just a successful backward pass). However, for whatever reason, I am getting an error like the following for Llama 70B:

```
File "/datadrive/Finetune_LLMs/finetuning_repo/llama_patch.py", line 47, in forward
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[1, 190, 64, 128]' is invalid for input of size 194560
```
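A quick worked check of the numbers in that error may help, assuming the published Llama 2 70B config values (`hidden_size=8192`, `num_attention_heads=64`, `num_key_value_heads=8`). The 70B model uses grouped-query attention, so `k_proj` produces only 8 × 128 = 1024 features per token instead of 8192, and a `view` that assumes 64 key heads asks for eight times as many elements as actually exist. This sketch uses plain Python arithmetic rather than tensors, with `bsz` and `q_len` taken from the traceback.

```python
# Llama 2 70B attention dimensions (assumed from its Hugging Face config).
HIDDEN_SIZE = 8192
NUM_HEADS = 64                       # number of query heads
NUM_KV_HEADS = 8                     # number of key/value heads (grouped-query attention)
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS  # 128

bsz, q_len = 1, 190  # batch size and sequence length from the traceback

# k_proj maps hidden_size -> num_key_value_heads * head_dim features per token.
k_proj_out_features = NUM_KV_HEADS * HEAD_DIM  # 1024
actual_elems = bsz * q_len * k_proj_out_features

# The patch reshapes as if there were num_heads key heads.
expected_elems = bsz * q_len * NUM_HEADS * HEAD_DIM

print(actual_elems)    # the "input of size" reported in the error
print(expected_elems)  # the element count view(1, 190, 64, 128) would need
```

The ratio between the two counts is exactly `NUM_HEADS // NUM_KV_HEADS`, which is why a reshape that works for 7B and 13B (where the two head counts are equal) fails only for 70B.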

mallorbc commented 1 year ago

For anyone else wondering why: https://github.com/philschmid/deep-learning-pytorch-huggingface/pull/30

philschmid commented 1 year ago

Hey @mallorbc,

I needed to revert #30 since it broke training for 7B and 13B. I haven't had the chance to look at it again.

mallorbc commented 1 year ago

@philschmid Thanks for getting back to me and thanks for your work and blog post!

I eventually saw that.

I was able to use your work to get fine-tuning working for 7B, 13B, and 70B. Instead of a single forward pass, having two, chosen based on which model is being used, is good enough for me!
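One possible way to get a single code path instead of two is to read the number of key/value heads from the config and reshape keys and values with it. This is a minimal sketch, not the code from either linked repo: `LlamaAttnConfig` and `attention_view_shapes` are hypothetical names (the field names mirror the Hugging Face `LlamaConfig` attributes), and the helper only computes the shapes a patched `forward` would pass to `.view(...)`, not the attention itself.

```python
from dataclasses import dataclass


@dataclass
class LlamaAttnConfig:
    # Hypothetical stand-in for the relevant fields of a Llama config.
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int


def attention_view_shapes(cfg, bsz, q_len):
    """Return the (bsz, q_len, heads, head_dim) view shapes for q/k/v plus the repeat factor."""
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    if cfg.num_attention_heads % cfg.num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    # How many query heads share each key/value head (1 means plain multi-head attention).
    n_rep = cfg.num_attention_heads // cfg.num_key_value_heads
    return {
        "query": (bsz, q_len, cfg.num_attention_heads, head_dim),
        # Keys and values use the KV head count, which differs from the query count for 70B.
        "key": (bsz, q_len, cfg.num_key_value_heads, head_dim),
        "value": (bsz, q_len, cfg.num_key_value_heads, head_dim),
        "n_rep": n_rep,
    }


# Assumed Llama 2 attention settings.
LLAMA2_7B = LlamaAttnConfig(hidden_size=4096, num_attention_heads=32, num_key_value_heads=32)
LLAMA2_13B = LlamaAttnConfig(hidden_size=5120, num_attention_heads=40, num_key_value_heads=40)
LLAMA2_70B = LlamaAttnConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)

for name, cfg in [("7B", LLAMA2_7B), ("13B", LLAMA2_13B), ("70B", LLAMA2_70B)]:
    print(name, attention_view_shapes(cfg, bsz=1, q_len=190))
```

For 7B and 13B `n_rep` comes out as 1, while for 70B it is 8, so each key/value head would need to be repeated 8 times to line up with the 64 query heads before any kernel that expects equal head counts. Recent transformers releases include a `repeat_kv` helper in `modeling_llama` for this step.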

Used it here: https://github.com/mallorbc/Finetune_LLMs/blob/main/finetuning_repo/llama_patch.py

Thanks again!

philschmid commented 1 year ago

@mallorbc Ah, nice! I'll try to make it compatible with both soonish. But we are also working on adding native support in transformers, so in a few weeks you should no longer need to patch these.

mallorbc commented 1 year ago

@philschmid The old 70B patch, while it supports a backward and forward pass, still has issues.

When I generate text with the model after training with QLoRA, I don't get the expected results. When I try to use text-generation-inference, I get shape errors as well.

These issues do not exist for the 7B and 13B models, which work great!

Just thought I would let you know. Thanks!

mallorbc commented 1 year ago

This repo has a working implementation for all three models: 7B, 13B, and 70B. It's licensed under GPL 3.0, but that is fine for my repo, which is AGPL.

https://github.com/oKatanaaa/llama-flash-attention-patch/tree/master

He took part of the solution from here, which is licensed under Apache 2.0: https://github.com/LAION-AI/Open-Assistant/blob/04fa9a24b2a58c8885b8aa6a2eb02b18de6b4961/model/model_training/models/patching_llama.py

May be useful.