Closed: pass-lin closed this issue 1 month ago
I worked on this today and it seems to work now, so I'll push a new release with it once I've verified all the tests.
Check https://github.com/nshepperd/flash_attn_jax/releases/tag/v0.2.0 — it should work now. You can also install it with pip now: `pip install flash-attn-jax==0.2.0`. Note that the version on PyPI is built only for CUDA 12.3, because PyPI is limited like that.
The signature is `flash_mha(q, k, v, softmax_scale=None, is_causal=False, window_size=(-1, -1))`, and it expects inputs shaped like this:

- `q`: `(batch_size, seqlen, nheads, headdim)`
- `k`: `(batch_size, seqlen, nheads_k, headdim)`
- `v`: `(batch_size, seqlen, nheads_k, headdim)`
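To illustrate the shape convention and the `softmax_scale`/`is_causal` semantics, here is a minimal plain-attention reference in NumPy. This is not the library's implementation (flash attention computes the same result blockwise on GPU); it's a sketch that assumes `nheads == nheads_k` (no grouped-query attention) and ignores `window_size`:

```python
import numpy as np

def reference_mha(q, k, v, softmax_scale=None, is_causal=False):
    # Plain scaled-dot-product attention for checking shapes/semantics.
    # q: (batch, seqlen, nheads, headdim); k, v: same shapes here
    # (assumes nheads == nheads_k, unlike the real API which allows MQA/GQA).
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(q.shape[-1])
    # Move the head axis forward: (batch, nheads, seqlen, headdim)
    q, k, v = (np.moveaxis(x, 2, 1) for x in (q, k, v))
    scores = np.einsum("bhqd,bhkd->bhqk", q, k) * softmax_scale
    if is_causal:
        # Mask out keys that come after each query position.
        seqlen = scores.shape[-1]
        mask = np.triu(np.ones((seqlen, seqlen), dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = np.einsum("bhqk,bhkd->bhqd", weights, v)
    # Back to (batch, seqlen, nheads, headdim), matching flash_mha's output.
    return np.moveaxis(out, 1, 2)

q = np.random.rand(2, 8, 4, 16).astype(np.float32)
k = np.random.rand(2, 8, 4, 16).astype(np.float32)
v = np.random.rand(2, 8, 4, 16).astype(np.float32)
out = reference_mha(q, k, v, is_causal=True)
print(out.shape)  # (2, 8, 4, 16)
```

With `is_causal=True`, position 0 can only attend to itself, so `out[:, 0]` equals `v[:, 0]` exactly — a quick sanity check for the masking.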