nshepperd / flash_attn_jax

JAX bindings for Flash Attention v2
BSD 3-Clause "New" or "Revised" License

Will it support grouped attention? #4

Closed. pass-lin closed this issue 1 month ago.

pass-lin commented 2 months ago

```
flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1))
q: (batch_size, seqlen, nheads,   headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
```

Can it accept inputs like this, where nheads_k differs from nheads?
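
For reference, here is a minimal sketch of the grouped-query call being asked about, with fewer key/value heads than query heads. The import path and signature follow the snippet quoted above; the specific shapes, dtype, and random inputs are illustrative assumptions, and a CUDA GPU is required.

```python
# Sketch only: shapes and dtype are illustrative assumptions.
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha  # signature as quoted above

batch, seqlen, headdim = 2, 1024, 64
nheads, nheads_k = 32, 8  # 32 query heads sharing 8 key/value heads (grouped-query attention)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seqlen, nheads,   headdim), dtype=jnp.float16)
k = jax.random.normal(kk, (batch, seqlen, nheads_k, headdim), dtype=jnp.float16)
v = jax.random.normal(kv, (batch, seqlen, nheads_k, headdim), dtype=jnp.float16)

out = flash_mha(q, k, v, is_causal=True)
print(out.shape)  # expected: (batch, seqlen, nheads, headdim)
```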

nshepperd commented 1 month ago

I worked on this today; it seems to work now, so I'll push a new release with it once I've verified all the tests.

nshepperd commented 1 month ago

Check https://github.com/nshepperd/flash_attn_jax/releases/tag/v0.2.0; it should work now. You can also install it with pip now: `pip install flash-attn-jax==0.2.0`, though the version on PyPI only supports CUDA 12.3 because... idk, PyPI is limited like that.
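
One hedged way to sanity-check the grouped-attention path in v0.2.0 is to expand the key/value heads so every query head gets its own copy and compare against the ungrouped call. This sketch assumes the usual grouped-query convention where query head h uses key/value head h // (nheads // nheads_k); the shapes and the fp16 tolerance are assumptions, and a CUDA GPU is required.

```python
# Sketch: compare grouped flash_mha against an ungrouped reference built by
# repeating the key/value heads. Grouping convention and tolerance are assumptions.
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

batch, seqlen, headdim = 2, 256, 64
nheads, nheads_k = 8, 2
group = nheads // nheads_k  # query heads per key/value head

kq, kk, kv = jax.random.split(jax.random.PRNGKey(1), 3)
q = jax.random.normal(kq, (batch, seqlen, nheads,   headdim), dtype=jnp.float16)
k = jax.random.normal(kk, (batch, seqlen, nheads_k, headdim), dtype=jnp.float16)
v = jax.random.normal(kv, (batch, seqlen, nheads_k, headdim), dtype=jnp.float16)

out_grouped = flash_mha(q, k, v)

# Reference: one key/value head per query head, so nheads_k == nheads.
k_full = jnp.repeat(k, group, axis=2)
v_full = jnp.repeat(v, group, axis=2)
out_full = flash_mha(q, k_full, v_full)

print(jnp.max(jnp.abs(out_grouped - out_full)))  # should be ~0 up to fp16 noise
```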