cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

[WIP] Add Attention module #66

Open: lkhphuc opened this pull request 2 years ago

lkhphuc commented 2 years ago

Adds an attention module as a wrapper around flax.linen.attention.

I think the wrapper is correct, but I cannot get test_equivalance to pass when using initializers that need an RNG. I suspect there is a mismatch between next_key() and my manual emulation of it.
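One way to narrow down the mismatch is to check the manual emulation against the key generator in isolation. The sketch below assumes a `next_key()`-style helper that, on each call, splits its current key and hands out the second half. `KeySeq`, `manual_keys` and `naive_keys` are hypothetical names, and treex's actual `next_key()` may derive keys differently. The example shows that an emulation matches only if it repeats the exact same split sequence. Splitting once into `n` keys produces a different set of keys.

```python
import jax
import jax.numpy as jnp


class KeySeq:
    """Hypothetical stateful generator in the style of a next_key() helper.

    Assumption: each call splits the current key, keeps the first half as the
    new state and returns the second half. Not necessarily treex's behaviour.
    """

    def __init__(self, key):
        self.key = key

    def next_key(self):
        self.key, subkey = jax.random.split(self.key)
        return subkey


def manual_keys(key, n):
    """Emulate n calls to KeySeq.next_key() by repeating the same split chain."""
    keys = []
    for _ in range(n):
        key, subkey = jax.random.split(key)
        keys.append(subkey)
    return keys


def naive_keys(key, n):
    """A different 'emulation': one split into n keys (does not match KeySeq)."""
    return list(jax.random.split(key, n))


key = jax.random.PRNGKey(42)
seq = KeySeq(key)
from_seq = [seq.next_key() for _ in range(3)]

same_as_manual = all(jnp.array_equal(a, b) for a, b in zip(from_seq, manual_keys(key, 3)))
same_as_naive = all(jnp.array_equal(a, b) for a, b in zip(from_seq, naive_keys(key, 3)))
print(same_as_manual, same_as_naive)
```

If the reference side of the test is a Flax module, keep in mind that, to my understanding, Flax derives each parameter's key from the "params" RNG by folding in the submodule name and a per-scope counter rather than by sequential splitting. In that case no split-based emulation will line up with it. Initializing the Flax module once and copying its params into the wrapper before comparing outputs avoids having to emulate keys at all.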

Todo: