Open: SimJeg opened this issue 2 months ago
Thanks for the question. We keep the head dimension intact, as shown at https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L97. In our update_kv, we also keep the head dimension throughout the calculations.
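To make this concrete, here is a minimal sketch of the repeat_kv expansion used in the transformers-style Mistral attention path that SnapKV hijacks (the helper below mirrors transformers' repeat_kv; the Mistral-7B shapes are illustrative only):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, head_dim) to
    # (batch, num_kv_heads * n_rep, seq_len, head_dim); this mirrors
    # the repeat_kv helper in transformers' Mistral/Llama attention.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Mistral-7B: 32 query heads, 8 KV heads -> n_rep = 4
key_states = torch.randn(1, 8, 1024, 128)
key_states = repeat_kv(key_states, 4)
print(key_states.shape)  # torch.Size([1, 32, 1024, 128])
```

Because the KV heads are expanded before the compression step, the selected cache ends up with 32 heads rather than 8, consistent with the answer above.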
I'm still confused.
With grouped-query attention there are more query heads than key and value heads (e.g. 32 query heads for 8 KV heads). SnapKV filters keys and values using the attention scores of the latest queries, so each key/value pair receives, for instance, 4 different scores. How do you average these 4 scores into a single one? In other words, the past key-values should have 8 heads, not 32. Is that the case in SnapKV?
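For concreteness, this is the kind of reduction I have in mind; a hypothetical sketch with illustrative names and shapes, not SnapKV code:

```python
import torch

# Hypothetical reduction of per-query-head scores to per-KV-head scores
# by averaging within each GQA group (illustrative only).
batch, num_heads, num_kv_heads, seq_len = 1, 32, 8, 1024
n_rep = num_heads // num_kv_heads  # 4 query heads share each KV head

scores = torch.randn(batch, num_heads, seq_len)         # one score per query head
grouped = scores.view(batch, num_kv_heads, n_rep, seq_len)
kv_scores = grouped.mean(dim=2)                         # or grouped.amax(dim=2)
print(kv_scores.shape)  # torch.Size([1, 8, 1024]): one score per KV head
```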
I have the same question. It seems there is no averaging or other processing across the repeated KV heads to reduce them back to the GQA layout. Or does SnapKV effectively turn GQA into MHA by always repeating the KV heads?
Hello, I am the author of Ada-KV, a follow-up work to SnapKV. Recently, we integrated GQA support into SnapKV and our Ada-KV. Experimental results show that, after enabling GQA with only 25% of the original cache size on Mistral-7B-Instruct-v0.2, both SnapKV and our Ada-KV continue to perform well, with only a slight quality drop. We have made our code and preliminary results on LongBench publicly available in our repository. We sincerely thank the SnapKV team for releasing their paper and code, which greatly contributed to advancing our research!
Hello,
Could you clarify how you handle grouped-query attention? For instance, in Mistral 7B there are 8 key-value heads and 32 query heads, so a given key-value pair is associated with 4 different queries and hence 4 different attention weights. How do you aggregate these 4 values? I do see the num_key_value_groups variable in the update_kv method, but it is not used. Thanks!
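For reference, the shape mismatch behind the question, using Mistral-7B's config values (a small illustrative snippet, not code from the repository):

```python
# Mistral-7B attention geometry (values from its HF config)
num_attention_heads = 32
num_key_value_heads = 8
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4

# Raw attention weights have shape (batch, 32, q_len, k_len), while the
# KV cache has shape (batch, 8, k_len, head_dim): each cached key/value
# pair therefore receives 4 attention scores, one per query head in its group.
print(num_key_value_groups)  # 4
```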