Adds an attention module as a wrapper around `flax.linen.attention`.
I think the wrapper is correct, but I cannot get `test_equivalance` to pass when using an initializer that needs an RNG key. I suspect there is a mismatch between `next_key()` and my manual emulation of it.
Todo:
[x] Pass test initialization with stochastic init.
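
To illustrate the kind of mismatch I suspect, here is a minimal sketch in plain JAX. `KeySeq` and its split convention are hypothetical stand-ins, not the actual `next_key()` implementation: the point is only that a manual emulation must split the key in the same order and keep the same half, otherwise a stochastic initializer produces different parameters.

```python
import jax
import jax.numpy as jnp


class KeySeq:
    """Hypothetical key stream; the real next_key() may use another convention."""

    def __init__(self, key):
        self._key = key

    def next_key(self):
        # Assumed convention: carry the first half forward, hand out the second.
        self._key, sub = jax.random.split(self._key)
        return sub


root = jax.random.PRNGKey(0)

# Keys as the (hypothetical) module would draw them.
seq = KeySeq(root)
k1 = seq.next_key()
k2 = seq.next_key()

# Manual emulation that follows the same convention step by step.
carry, m1 = jax.random.split(root)
carry, m2 = jax.random.split(carry)

# Manual emulation that takes the wrong half of the first split.
w1, _ = jax.random.split(root)

# A stochastic initializer only agrees when the keys agree.
init = jax.nn.initializers.normal(stddev=0.02)
a = init(k1, (4, 4))
b = init(m1, (4, 4))
c = init(w1, (4, 4))

same = bool(jnp.array_equal(a, b))        # matching emulation
differ = not bool(jnp.array_equal(a, c))  # mismatched emulation
```

If the real key stream uses a different scheme (for example `jax.random.fold_in` or a different split order), the emulation in the test has to mirror that scheme exactly.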