pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

A simple adaptation to JAX #59

Open · zinccat opened this issue 1 week ago

zinccat commented 1 week ago

I've just done a bit of straightforward adaptation to JAX; performance is ~20% slower: https://github.com/zinccat/flaxattention. Thanks for the great project!
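For context, the interface being ported centers on a `score_mod` callback that rewrites each attention score given its batch, head, query, and key indices. Below is a minimal, unfused JAX sketch of that interface; it materializes the full score matrix and is only meant to illustrate the API shape, not the fused kernels used by flex_attention or flaxattention.

```python
import jax
import jax.numpy as jnp


def relative_bias(score, b, h, q_idx, kv_idx):
    # Example score_mod: add a bias that depends on query/key distance.
    return score + (q_idx - kv_idx)


def naive_flex_attention(q, k, v, score_mod):
    # q, k, v: [batch, heads, seq, head_dim]. Unfused reference, not a kernel.
    B, H, S, D = q.shape
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(D)
    # Broadcastable index grids passed to the user-supplied score_mod.
    b_idx = jnp.arange(B)[:, None, None, None]
    h_idx = jnp.arange(H)[None, :, None, None]
    q_idx = jnp.arange(S)[None, None, :, None]
    kv_idx = jnp.arange(S)[None, None, None, :]
    scores = score_mod(scores, b_idx, h_idx, q_idx, kv_idx)
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


# Usage: out = naive_flex_attention(q, k, v, relative_bias)
```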

Chillee commented 6 days ago

This is fun :) On a quick skim, I think the parts you haven't adapted yet are:

  1. block-sparsity (a rough sketch of the idea follows this list)
  2. autograd support
  3. "captured" buffers
zinccat commented 6 days ago

Yeah, I'm trying to work on autograd first.
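The usual route for this in JAX is `jax.custom_vjp`: keep the hand-written forward and register a hand-written backward. Here is a self-contained sketch on a plain, unfused, single-head attention with the softmax backward written out; it is illustrative only, and a FlashAttention-style kernel would save the logsumexp rather than the full probability matrix.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def attention(q, k, v):
    # q, k, v: [seq, head_dim], single head for brevity.
    s = jnp.einsum("qd,kd->qk", q, k) / jnp.sqrt(q.shape[-1])
    p = jax.nn.softmax(s, axis=-1)
    return jnp.einsum("qk,kd->qd", p, v)


def attention_fwd(q, k, v):
    s = jnp.einsum("qd,kd->qk", q, k) / jnp.sqrt(q.shape[-1])
    p = jax.nn.softmax(s, axis=-1)
    out = jnp.einsum("qk,kd->qd", p, v)
    # Residuals for the backward pass (a real kernel would save less).
    return out, (q, k, v, p)


def attention_bwd(res, g):
    q, k, v, p = res
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    dv = jnp.einsum("qk,qd->kd", p, g)
    dp = jnp.einsum("qd,kd->qk", g, v)
    # Softmax backward: ds = p * (dp - sum(dp * p)) row-wise.
    ds = p * (dp - jnp.sum(dp * p, axis=-1, keepdims=True))
    dq = jnp.einsum("qk,kd->qd", ds, k) * scale
    dk = jnp.einsum("qk,qd->kd", ds, q) * scale
    return dq, dk, dv


attention.defvjp(attention_fwd, attention_bwd)
```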

Chillee commented 6 days ago

I think autograd won't be so bad, but "captured" buffers would be much more difficult to do in JAX.
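For reference, a "captured" buffer in flex_attention is a tensor the score_mod closes over (per-head ALiBi slopes, for example) that the compiled kernel then reads. One reading of why this is awkward in JAX: an array captured by a Python closure gets baked into the trace as a constant under jit, so changing it either silently has no effect (if the cached trace is reused) or forces a retrace (if the score_mod is treated as a static argument), and the library never sees the buffer as an operand it could load inside a kernel. A common workaround is to thread the buffer through as an explicit argument. The sketch below is hypothetical and not the flaxattention API.

```python
import jax.numpy as jnp


def make_alibi_score_mod(slopes):
    # Closure-captured buffer: `slopes` has shape [num_heads]. Under jit the
    # array becomes a constant of the traced computation.
    def score_mod(score, b, h, q_idx, kv_idx):
        return score - slopes[h] * jnp.abs(q_idx - kv_idx)
    return score_mod


def alibi_score_mod(score, b, h, q_idx, kv_idx, *, slopes):
    # Explicit-argument variant: `slopes` is an ordinary traced input, so it
    # can change between calls without rebuilding the closure, and a kernel
    # could treat it as an operand to load from.
    return score - slopes[h] * jnp.abs(q_idx - kv_idx)
```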