fkodom / grouped-query-attention-pytorch

(Unofficial) PyTorch implementation of grouped-query attention (GQA) from "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (https://arxiv.org/pdf/2305.13245.pdf)
MIT License

gqa model runtime > no gqa model runtime #4

Open Adonai02 opened 8 months ago

Adonai02 commented 8 months ago

Hi! I'm trying to replicate your implementation with Llama 2 13B and 7B, but curiously the runtimes don't make sense (Llama 2 WITH GQA is slower than Llama 2 WITHOUT GQA). There is a small difference between my code and yours.

I mean-pool the attention projection weights to group them as described in the GQA paper, then load the new attn_wts with the load_state_dict method, but as I said, it didn't work.
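A minimal sketch of what I mean by mean-pooling (hypothetical shapes and names, not code from this repo): given a key or value projection weight of shape (num_heads * head_dim, embed_dim), I average the heads within each group and load the result with load_state_dict.

import torch

def mean_pool_kv_heads(weight: torch.Tensor, num_heads: int, num_groups: int) -> torch.Tensor:
    # weight: (num_heads * head_dim, embed_dim), heads contiguous along dim 0.
    head_dim = weight.shape[0] // num_heads
    embed_dim = weight.shape[1]
    w = weight.reshape(num_groups, num_heads // num_groups, head_dim, embed_dim)
    # Mean-pool the heads within each group, as in the GQA uptraining recipe.
    pooled = w.mean(dim=1)
    # Back to (num_groups * head_dim, embed_dim) so it can be loaded via load_state_dict.
    return pooled.reshape(num_groups * head_dim, embed_dim)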

Then I tried your implementation out of the box: I simply downloaded the repo and copy-pasted the t5 example, but it doesn't seem to work either.

%%time
outputs = t5.generate(input_ids, max_new_tokens=25)
CPU times: user 2.23 s, sys: 0 ns, total: 2.23 s
Wall time: 97 ms

%%time
outputs = t5_gqa.generate(input_ids, max_new_tokens=25)
CPU times: user 9.78 s, sys: 0 ns, total: 9.78 s
Wall time: 414 ms
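For reference, the numbers above are single %%time runs; a rough re-timing helper like the sketch below (my own placeholder, not from the repo), with a warm-up call and explicit CUDA synchronization, might give a more stable comparison:

import time
import torch

def time_generate(model, input_ids, n_iters: int = 5, max_new_tokens: int = 25) -> float:
    # Warm-up call so one-time CUDA initialization is not counted.
    model.generate(input_ids, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        model.generate(input_ids, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

# e.g. compare time_generate(t5, input_ids) vs. time_generate(t5_gqa, input_ids)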

I use the following packages with Python 3.10:

Hardware: A100 80GB

ylacombe commented 3 months ago

Hey @Adonai02 and @fkodom, have you been able to figure out what was going on? Thanks!