(Unofficial) PyTorch implementation of grouped-query attention (GQA) from "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (https://arxiv.org/pdf/2305.13245.pdf)
Hi! I'm trying to replicate your implementation with Llama 2 13B and 7B, but curiously the runtimes don't make sense (Llama 2 with GQA is *slower* than Llama 2 WITHOUT GQA). There is a small difference between my code and yours.
I mean-pool the attention projection weights to group them, as in the GQA paper, and then load these new weights with the load_state_dict method, but as I said, it didn't work.
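For reference, a minimal sketch of the mean-pooling step the GQA paper describes (averaging the key/value projection weights of the heads within each group). All names and dimensions here (`num_heads`, `num_groups`, `head_dim`, `d_model`) are illustrative assumptions, not taken from this repo or from Llama 2:

```python
import torch

# Assumed toy dimensions for illustration only.
num_heads, num_groups, head_dim, d_model = 8, 2, 64, 512
heads_per_group = num_heads // num_groups

# Original multi-head K projection weight: (num_heads * head_dim, d_model).
k_proj = torch.randn(num_heads * head_dim, d_model)

# Split into per-head weights, then mean-pool the heads within each group.
k_heads = k_proj.view(num_heads, head_dim, d_model)
k_grouped = k_heads.view(num_groups, heads_per_group, head_dim, d_model).mean(dim=1)

# New grouped K projection weight: (num_groups * head_dim, d_model).
k_proj_gqa = k_grouped.reshape(num_groups * head_dim, d_model)
print(k_proj_gqa.shape)  # torch.Size([128, 512])
```

The same pooling would be applied to the V projection; the Q projection keeps all heads.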
Then I tried your implementation out of the box: I simply downloaded the repo and copy-pasted the T5 example, but it doesn't seem to work either.
```
%%time
outputs = t5.generate(input_ids, max_new_tokens=25)

CPU times: user 2.23 s, sys: 0 ns, total: 2.23 s
Wall time: 97 ms
```

```
%%time
outputs = t5_gqa.generate(input_ids, max_new_tokens=25)

CPU times: user 9.78 s, sys: 0 ns, total: 9.78 s
Wall time: 414 ms
```
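As a side note, `%%time` on a single GPU generation can be misleading, since CUDA kernel launches are asynchronous and the first call pays one-time setup costs. A hedged timing sketch (assuming `model` and `input_ids` already exist) that warms up and synchronizes before reading the clock:

```python
import time
import torch

def time_generate(model, input_ids, max_new_tokens=25, warmup=1, iters=5):
    """Average wall time of model.generate over several iterations.

    Runs warmup passes first and synchronizes CUDA (when available) so
    queued kernels are not counted against the wrong interval.
    """
    for _ in range(warmup):
        model.generate(input_ids, max_new_tokens=max_new_tokens)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model.generate(input_ids, max_new_tokens=max_new_tokens)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```

Comparing `time_generate(t5, input_ids)` against `time_generate(t5_gqa, input_ids)` would give a fairer picture than a single timed cell.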
I'm using the following packages with Python 3.10:
transformers 4.36.2
torch 2.0.1+cu117
accelerate 0.25.0
Hardware: A100 80GB