philschmid / deep-learning-pytorch-huggingface

MIT License

Llama patch for FlashAttention support fails with use_cache #26

Open qmdnls opened 1 year ago

qmdnls commented 1 year ago

I came across your llama_patch.py while looking to patch Llama for inference myself and, unless I'm doing something wrong, the implementation fails when use_cache=True and past_key_value is not None.

Specifically, during generation with use_cache=True, at this line query_states will have sequence length 1, while key_states and value_states will have length 1 + past_key_value[0].shape[-2], so these tensors won't stack.

https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/05d83eaa3c2ad6088227fa26dffb097e06439aef/training/utils/llama_patch.py#L76C3-L76C3
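To illustrate the mismatch, here is a minimal shape-only sketch in plain Python. It does not call PyTorch or the patch itself; it only mimics the rule that stacking requires identical input shapes. The helper names `decode_step_shapes` and `stacked_qkv_shape` are hypothetical, and the `(batch, seq_len, n_heads, head_dim)` layout with stacking along axis 2 is an assumption about how the patch packs q, k and v.

```python
# Shape-only sketch: no tensors are created, only shape tuples are compared.
# Assumed layout: (batch, seq_len, n_heads, head_dim), stacked along axis 2.


def decode_step_shapes(bsz, n_heads, head_dim, past_len, new_len=1):
    """Return (q, k, v) shapes after cached keys/values are concatenated."""
    q = (bsz, new_len, n_heads, head_dim)
    # With a KV cache, keys and values cover the past tokens plus the new ones.
    kv = (bsz, past_len + new_len, n_heads, head_dim)
    return q, kv, kv


def stacked_qkv_shape(q_shape, k_shape, v_shape):
    """Return the packed qkv shape, or raise if the shapes differ."""
    if not (q_shape == k_shape == v_shape):
        raise ValueError(
            f"cannot stack: q={q_shape}, k={k_shape}, v={v_shape}"
        )
    return q_shape[:2] + (3,) + q_shape[2:]


# Training or prefill: no cache, so q, k and v share one sequence length.
print(stacked_qkv_shape(*decode_step_shapes(1, 32, 128, past_len=0, new_len=16)))

# Cached decoding: q has length 1, k and v have length 1 + past_len.
try:
    stacked_qkv_shape(*decode_step_shapes(1, 32, 128, past_len=16))
except ValueError as err:
    print(err)
```

The first call succeeds because all three shapes match. The second fails for the reason described above: the query covers only the new token, while the keys and values also include the cached positions.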

I think this is also why the other Llama patches referenced in the comments don't support flash attention and the KV cache at the same time. Not sure if there's a clever workaround?

philschmid commented 1 year ago

Hey @qmdnls,

What you say could well be true. I created the patch only for training, where you use gradient checkpointing and no cache.

If you are interested in inference, I recommend checking out text-generation-inference.

qmdnls commented 1 year ago

I see, no worries! I just came across this and thought I would let you know, since the patch seemed to specifically implement the case with past_key_value, unlike the other referenced implementations.

Thanks for the pointer, I will have a look!