google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Add NNX transforms `nnx.while_loop` and `nnx.switch` #4343

Closed: IvyZX closed this pull request 3 weeks ago.

IvyZX commented 4 weeks ago
cgarciae commented 3 weeks ago

BTW: Can we move `while_loop` to `iteration.py`? I've been meaning to remove `transforms.py`, which previously held all the transforms in a single file that became too big.