pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

Support for Group Query Attention #82

Closed · tugot17 closed 3 days ago

tugot17 commented 3 days ago

At the moment, with `torch==2.6.0.dev20241114+cu124`, I get the error `ValueError: Expect query and key/value to have the same number of heads but got Hq=32 and Hkv=8. Try setting enable_gqa=True for GQA` when I try to use flex-attention with a different number of KV and Q heads.

Do you plan to add support for Group Query Attention in future versions?

*edit: skill issue, it is actually available as the `enable_gqa` parameter (see the sketch below)
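
For reference, a minimal sketch of a GQA call with `flex_attention`. The head counts (32 query heads, 8 KV heads) mirror the error message above; the batch size, sequence length, and head dimension are illustrative assumptions:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative shapes: Hq=32 query heads, Hkv=8 KV heads (GQA),
# matching the head counts from the error message.
B, Hq, Hkv, S, D = 2, 32, 8, 1024, 64

query = torch.randn(B, Hq, S, D)
key = torch.randn(B, Hkv, S, D)
value = torch.randn(B, Hkv, S, D)

# enable_gqa=True tells flex_attention to broadcast the 8 KV heads
# across the 32 query heads (each KV head serves Hq // Hkv = 4 query
# heads). Without it, mismatched head counts raise the ValueError above.
out = flex_attention(query, key, value, enable_gqa=True)
print(out.shape)  # torch.Size([2, 32, 1024, 64])
```

Note that `Hq` must be an integer multiple of `Hkv` for the broadcast to be well defined.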