FasterDecoding / SnapKV


Group Query Attention #22

Open SimJeg opened 2 months ago

SimJeg commented 2 months ago

Hello,

Could you clarify how you handle group query attention? For instance, in Mistral 7B there are 8 key-value heads and 32 query heads, so a given key-value pair is associated with 4 different queries and hence 4 different attention weights. How do you aggregate these 4 values? I do see the num_key_value_groups variable in the update_kv method, but it is not used.
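
For concreteness, here are the shapes involved (Mistral-7B numbers; the variable names below are only illustrative, not taken from the SnapKV code):

```python
import torch

# Mistral-7B: 32 query heads, 8 key/value heads
num_heads, num_kv_heads, head_dim = 32, 8, 128
num_key_value_groups = num_heads // num_kv_heads  # 4 query heads share each KV head

bsz, q_len, kv_len = 1, 64, 4096
query_states = torch.randn(bsz, num_heads, q_len, head_dim)    # 32 heads
key_states = torch.randn(bsz, num_kv_heads, kv_len, head_dim)  # 8 heads

# Attention weights therefore have 32 heads while the KV cache has only 8:
# every cached key/value pair receives 4 different score vectors,
# one from each query head in its group.
```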

Thanks!

WendyH1108 commented 1 month ago

Thanks for the question. We keep the head dimension intact, as in https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L97. In our update_kv, we also keep the head dimension throughout the calculations.
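
If the linked line is the standard repeat_kv call from the Hugging Face Mistral attention path, the 8 KV heads are expanded to 32 before the attention (and hence the importance) scores are computed. A sketch of that stock helper, for reference:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (bsz, num_kv_heads, seq_len, head_dim) to
    (bsz, num_kv_heads * n_rep, seq_len, head_dim) -- the stock
    Hugging Face helper used by the Llama/Mistral attention modules."""
    bsz, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, slen, head_dim)
```

If that is the path taken, the scores stay per query head and no cross-head aggregation is needed, but the retained cache then effectively has 32 rather than 8 heads (the point raised in the follow-up comments below).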

SimJeg commented 1 month ago

I'm still confused.

With group query attention there are more query heads than key-value heads (e.g. 32 query heads for 8 KV heads). SnapKV is based on filtering keys and values using the latest queries, so for each key/value pair you get, for instance, 4 different scores. How do you average these 4 scores into a single one? In other words, the past key-values should have 8 heads and not 32; is that the case in SnapKV?
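
To make the question concrete, one natural aggregation would be a mean over the num_key_value_groups dimension, reducing 32 score vectors to 8. This is purely hypothetical, to illustrate what is being asked, not a claim about what SnapKV currently does:

```python
import torch

bsz, num_heads, num_kv_heads, kv_len = 1, 32, 8, 4096
num_key_value_groups = num_heads // num_kv_heads  # 4

# Per-query-head importance scores over the cached positions
# (e.g. attention mass from the latest queries / observation window).
scores = torch.rand(bsz, num_heads, kv_len)

# Hypothetical reduction: average the 4 query heads sharing each KV head.
scores_per_kv_head = scores.view(
    bsz, num_kv_heads, num_key_value_groups, kv_len
).mean(dim=2)
print(scores_per_kv_head.shape)  # torch.Size([1, 8, 4096])
```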


FdyCN commented 2 weeks ago

I have the same question. It seems there is no averaging or other reduction across the repeated KV heads to get back to GQA. Or does SnapKV effectively change GQA into MHA by always repeating the KV heads?

FFY0 commented 2 weeks ago

Hello, I am the author of Ada-KV, a follow-up work to SnapKV. We recently worked on integrating GQA support into SnapKV and our Ada-KV. Experimental results show that, with GQA support enabled and only 25% of the original cache size on Mistral-7B-Instruct-v0.2, both SnapKV and our Ada-KV continue to perform well, with only a slight quality drop. We have made our code and preliminary results on LongBench publicly available in our repository. We sincerely thank the SnapKV team for releasing their paper and code, which greatly contributed to advancing our research!
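
The exact integration is in the Ada-KV repository; as a rough sketch of the general idea behind a GQA-aware selection (aggregate the observation-window scores within each query group so that the selection, and the retained cache, stays at the 8 KV heads), something like the following could be done. The function and parameter names (compress_kv_gqa, max_capacity, kernel_size) are illustrative and not taken from the SnapKV or Ada-KV code:

```python
import torch
import torch.nn.functional as F

def compress_kv_gqa(key_states, value_states, attn_weights,
                    num_key_value_groups, max_capacity, kernel_size=5):
    """Illustrative GQA-aware, SnapKV-style selection (not the released code).

    key_states / value_states: (bsz, num_kv_heads, kv_len, head_dim)
    attn_weights: observation-window attention, (bsz, num_heads, window, kv_len)
    """
    bsz, num_kv_heads, kv_len, head_dim = key_states.shape

    # Attention mass each cached position receives from the observation window.
    scores = attn_weights.sum(dim=2)  # (bsz, num_heads, kv_len)

    # Aggregate the query heads that share a KV head (mean over the group),
    # so the scores match the 8-head cache rather than the 32 query heads.
    scores = scores.view(bsz, num_kv_heads, num_key_value_groups, kv_len).mean(dim=2)

    # SnapKV-style pooling to smooth the selection, then keep the top positions.
    scores = F.max_pool1d(scores, kernel_size=kernel_size,
                          padding=kernel_size // 2, stride=1)
    idx = scores.topk(min(max_capacity, kv_len), dim=-1).indices
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)

    compressed_keys = key_states.gather(dim=2, index=idx)
    compressed_values = value_states.gather(dim=2, index=idx)
    return compressed_keys, compressed_values  # still 8 KV heads, fewer positions
```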