Open Adonai02 opened 10 months ago
That would be nice, but it's a bit outside the scope of transformers! It would be great if you had a working example! What I recommend is to register a load_state_dict hook that converts the checkpoint on the fly. The benchmark should run on different numbers of KV heads, as some shapes might be less optimal? That would be my intuition. Also, a single head (MQA) should always be faster than MHA.
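A minimal sketch of the load_state_dict-pre-hook idea suggested above, assuming a Llama-style attention layout with `k_proj`/`v_proj` linear layers; it relies on PyTorch's private `_register_load_state_dict_pre_hook`, and the helper names here are hypothetical, not part of transformers.

```python
# Hypothetical sketch: a load_state_dict pre-hook that mean-pools the MHA
# k/v projection weights down to a smaller number of KV heads before they
# are copied into a GQA-shaped module.
import torch
from torch import nn


def mean_pool_kv(weight: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int) -> torch.Tensor:
    # weight: (num_heads * head_dim, hidden_size) from the original MHA checkpoint.
    hidden_size = weight.shape[1]
    w = weight.view(num_kv_heads, num_heads // num_kv_heads, head_dim, hidden_size)
    # Average the heads that fall into the same KV group.
    return w.mean(dim=1).reshape(num_kv_heads * head_dim, hidden_size)


def make_gqa_pre_hook(num_heads: int, num_kv_heads: int, head_dim: int):
    def hook(state_dict, prefix, local_metadata, strict,
             missing_keys, unexpected_keys, error_msgs):
        for proj in ("k_proj", "v_proj"):
            key = f"{prefix}{proj}.weight"
            w = state_dict.get(key)
            # Only convert weights that still have the full-MHA shape.
            if w is not None and w.shape[0] == num_heads * head_dim:
                state_dict[key] = mean_pool_kv(w, num_heads, num_kv_heads, head_dim)
    return hook


def register_gqa_hooks(model: nn.Module, num_heads: int, num_kv_heads: int, head_dim: int):
    # Attach the hook to every attention module, matched here by the
    # presence of a k_proj submodule (as in Llama-style models).
    for module in model.modules():
        if hasattr(module, "k_proj"):
            module._register_load_state_dict_pre_hook(
                make_gqa_pre_hook(num_heads, num_kv_heads, head_dim)
            )
```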
If I understand correctly, what is shown here is an attempt to implement GQA on a non-GQA Llama 2 13B model? If that's the case, and despite the slight loss of performance observed, does the context's VRAM footprint shrink as GQA allows, and is the model's perplexity affected? If that's not the case, sorry for the misunderstanding!
Feature request
It would be nice if, when I set a smaller number of key/value heads in the model's config (num_key_value_heads < num_attention_heads), the attention weights were automatically converted by mean pooling. Right now, if I do this, it gives me the following error.
num_key_value_heads = 4
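A minimal sketch of what the request describes, assuming the Llama 2 13B checkpoint for illustration: `num_key_value_heads` is the actual config field in transformers for Llama-style models, and lowering it changes the expected shape of the k/v projections, which is presumably why loading the unconverted MHA checkpoint fails.

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-13b-hf"  # assumed checkpoint for illustration
config = AutoConfig.from_pretrained(model_id)
config.num_key_value_heads = 4          # Llama 2 13B normally uses 40 heads for both Q and K/V

# Instantiating from the config alone works (random init, GQA-shaped k/v projections) ...
model = AutoModelForCausalLM.from_config(config)

# ... but loading the pretrained weights with the modified config fails with a
# size-mismatch error, because the checkpoint still stores full-MHA k_proj/v_proj
# weights that would need to be mean-pooled first.
# model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
```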
Motivation
Make models faster, e.g. Llama 2 13B, Llama 7B, Mistral 7B, etc.
Your contribution
I tried a simple implementation, but it gives me inconsistent results: the GQA model is slower than the non-GQA model.
Results: GQA Llama
Results: no-GQA Llama
I don't know if I'm misunderstanding something; please let me know if you can see something I can't.
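A hypothetical micro-benchmark sketch along the lines of the suggestion above: time generation for several values of `num_key_value_heads` on randomly initialised models, so shape effects can be compared independently of any checkpoint conversion. The model id, head counts, and token budgets are assumptions for illustration only.

```python
import time

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-13b-hf"  # assumed checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")

for kv_heads in (40, 8, 4, 1):  # 40 = full MHA, 1 = MQA
    config = AutoConfig.from_pretrained(model_id)
    config.num_key_value_heads = kv_heads
    # Random init is enough for a shape/latency comparison.
    model = AutoModelForCausalLM.from_config(config).to("cuda", dtype=torch.float16)
    model.eval()

    with torch.no_grad():
        # Warm-up pass, then a timed generation.
        model.generate(**inputs, max_new_tokens=32, do_sample=False,
                       pad_token_id=tokenizer.eos_token_id)
        torch.cuda.synchronize()
        start = time.perf_counter()
        model.generate(**inputs, max_new_tokens=128, do_sample=False,
                       pad_token_id=tokenizer.eos_token_id)
        torch.cuda.synchronize()
    print(f"kv_heads={kv_heads}: {time.perf_counter() - start:.2f}s for 128 new tokens")

    del model
    torch.cuda.empty_cache()
```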