Open qmdnls opened 1 year ago
Hey @qmdnls,
You may well be right. I created the patch only for training, where you use gradient checkpointing and no cache.
If you are interested in inference, I recommend checking out text-generation-inference.
I see, no worries! Just came across this and thought I would let you know, since the patch seemed to specifically implement the case with `past_key_value`, unlike the other referenced implementations.
Thanks for the pointer, I will have a look!
I came across your llama_patch.py when looking to patch Llama for inference myself, and unless I'm doing something wrong the implementation fails when `use_cache=True` and `past_key_value` is not `None`. Specifically, during generation with `use_cache=True`, in this line `query_states` will have sequence length 1 while `key_states` and `value_states` will have length `1 + past_key_value[0].shape[-2]`, and thus these tensors won't stack. https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/05d83eaa3c2ad6088227fa26dffb097e06439aef/training/utils/llama_patch.py#L76C3-L76C3
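To illustrate the shape mismatch: a minimal sketch using NumPy arrays as stand-ins for the attention tensors (the head count, head dimension, and past length below are made-up values, not taken from the patch). During incremental decoding only the newest token's query is computed, while the KV cache pads the key/value sequence dimension, so a qkv-packed stack fails:

```python
import numpy as np

past_len = 5  # hypothetical number of cached tokens

# (batch, seq_len, num_heads, head_dim) -- 32 heads / 128 dims are illustrative
query_states = np.zeros((1, 1, 32, 128))             # only the new token
key_states = np.zeros((1, 1 + past_len, 32, 128))    # cache + new token
value_states = np.zeros_like(key_states)

try:
    # the qkv-packed layout the flash-attn call expects
    qkv = np.stack([query_states, key_states, value_states], axis=2)
except ValueError as e:
    print("stack failed:", e)
```

The stack only succeeds when all three tensors share the same sequence length, which is exactly the case that breaks once the cache is in play.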
I think this is also why the other Llama patches referenced in the comments don't support flash attention + KV cache at the same time. Not sure if there's a clever workaround?