google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.04k stars, 639 forks

Support for RNN (NNX) #4259

Open: zinccat opened this issue 1 week ago

zinccat commented 1 week ago

Seems like the RNN family is not currently supported in the NNX API.

zinccat commented 1 week ago

Started working on a port: https://github.com/zinccat/rnn_nnx

cgarciae commented 1 week ago

Hey @zinccat, if you want to send PRs, I'm happy to review. We haven't ported any of the cells or the RNN class itself.