Open zinccat opened 1 week ago
This is fun :) On a quick skim, I think the parts you haven't adopted yet are:
yeah, I'm trying to work on autograd first
I think autograd won't be so bad, but "captured" buffers would be much more difficult to do in JAX.
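To make the "captured" buffers point concrete, here is a minimal sketch of an unfused reference attention in JAX, where a `score_mod` closes over a bias array and `jax.grad` differentiates through it. The names `reference_attention`, `make_bias_score_mod`, and `loss` are hypothetical, chosen for illustration; they are not the flaxattention API. This only covers the plain, unfused case, where closures work out of the box. It does not attempt a fused kernel, which is presumably where captured buffers get harder.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, score_mod):
    """Unfused single-head attention; q, k, v have shape (seq_len, head_dim)."""
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    q_idx = jnp.arange(q.shape[0])[:, None]   # shape (seq_q, 1)
    kv_idx = jnp.arange(k.shape[0])[None, :]  # shape (1, seq_kv)
    # Apply score_mod to every (q_idx, kv_idx) pair via broadcasting.
    scores = score_mod(scores, q_idx, kv_idx)
    return jax.nn.softmax(scores, axis=-1) @ v


def make_bias_score_mod(bias):
    # `bias` is the "captured" buffer: the closure indexes into it.
    def score_mod(score, q_idx, kv_idx):
        return score + bias[q_idx, kv_idx]
    return score_mod


def loss(bias, q, k, v):
    out = reference_attention(q, k, v, make_bias_score_mod(bias))
    return jnp.sum(out ** 2)


kq, kk, kv_key, kb = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(kq, (4, 8))
k = jax.random.normal(kk, (4, 8))
v = jax.random.normal(kv_key, (4, 8))
bias = jax.random.normal(kb, (4, 4))

# Gradient with respect to the captured buffer, under jit.
bias_grad = jax.jit(jax.grad(loss))(bias, q, k, v)
```

Here `bias_grad` has the same shape as `bias`, so in the unfused setting the captured buffer gets a gradient like any other input.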
I just did a fairly straightforward adaptation to JAX; performance is ~20% slower: https://github.com/zinccat/flaxattention Thanks for the great project!