google/flax
Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.15k stars, 648 forks
Add logical axis global context support for NNX #4350
Closed. IvyZX closed this 1 week ago.
IvyZX commented 3 weeks ago:
Makes it possible to use the logical axis context manager in NNX annotations.
- Moves the logical axis context annotation and part of the rule inference logic to flax.core.spmd, so that both Linen and NNX can share them.
- Fixes some small typos.
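The idea behind a global logical-axis context is that annotations name axes logically (for example "batch" or "mlp"), and a surrounding context supplies the rules that map those names onto physical mesh axes, so individual annotations never need to pass the rules explicitly. The following is a minimal plain-Python sketch of that pattern, not the flax.core.spmd implementation: `axis_rules` and `logical_to_mesh` are hypothetical names, and the "first matching rule whose mesh axis is still unused" resolution is an assumption made for illustration.

```python
import contextlib
import contextvars
from typing import Iterator, Optional, Sequence, Tuple

# Hypothetical rule format: (logical axis name, mesh axis name or None).
Rule = Tuple[str, Optional[str]]

# Holds the rules installed by the innermost enclosing `axis_rules` block.
_ACTIVE_RULES = contextvars.ContextVar("active_logical_axis_rules", default=())


@contextlib.contextmanager
def axis_rules(rules: Sequence[Rule]) -> Iterator[None]:
    """Install `rules` as the global logical-axis rules for the enclosed block."""
    token = _ACTIVE_RULES.set(tuple(rules))
    try:
        yield
    finally:
        # Restore whatever rules were active before (supports nesting).
        _ACTIVE_RULES.reset(token)


def logical_to_mesh(
    logical_axes: Sequence[Optional[str]],
    rules: Optional[Sequence[Rule]] = None,
) -> Tuple[Optional[str], ...]:
    """Map each logical axis name to a mesh axis name (None means replicated).

    Uses explicit `rules` if given, otherwise the rules from the enclosing
    `axis_rules` context. A mesh axis is assigned to at most one array axis.
    """
    active = tuple(rules) if rules is not None else _ACTIVE_RULES.get()
    used = set()
    result = []
    for name in logical_axes:
        mesh_axis = None
        if name is not None:
            for logical, candidate in active:
                # Take the first matching rule whose mesh axis is still free;
                # a rule mapping to None explicitly marks the axis replicated.
                if logical == name and (candidate is None or candidate not in used):
                    mesh_axis = candidate
                    break
        if mesh_axis is not None:
            used.add(mesh_axis)
        result.append(mesh_axis)
    return tuple(result)


if __name__ == "__main__":
    with axis_rules([("batch", "data"), ("embed", None), ("mlp", "model")]):
        print(logical_to_mesh(("batch", "mlp")))  # ('data', 'model')
        print(logical_to_mesh(("embed", "mlp")))  # (None, 'model')
    print(logical_to_mesh(("batch", "mlp")))  # (None, None): no rules outside
```

Using a context variable (rather than a plain module global) keeps the active rules correctly scoped when contexts are nested, which is the property a shared Linen/NNX context manager would need.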