Closed: IvyZX closed this 3 weeks ago.
BTW: Can we move `while_loop` to `iteration.py`? I've been meaning to remove `transforms.py`, which previously contained all the transforms in a single file that became too big.
Implemented `nnx.switch`, similar to `nnx.cond`.
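As a rough mental model of the switch semantics, assume `nnx.switch` follows `jax.lax.switch(index, branches, *operands)` the same way `nnx.cond` follows `jax.lax.cond`. Under that assumption, the integer `index` selects one branch, and out-of-range values are clamped to the first or last branch instead of raising. The pure-Python sketch below models only that selection rule. `reference_switch` is a hypothetical name, and the sketch does no tracing and none of the NNX state handling the real transform performs.

```python
from typing import Any, Callable, Sequence


def reference_switch(index: int, branches: Sequence[Callable[..., Any]], *operands: Any) -> Any:
    """Plain-Python model of jax.lax.switch's documented selection rule.

    The index is clamped to [0, len(branches) - 1] before a branch is chosen,
    so out-of-range indices pick the first or last branch instead of raising.
    """
    if not branches:
        raise ValueError("branches must be non-empty")
    clamped = min(max(index, 0), len(branches) - 1)
    return branches[clamped](*operands)


# Three branches with the same signature, as jax.lax.switch requires.
branches = [
    lambda x: x + 1,
    lambda x: x * 2,
    lambda x: -x,
]

print(reference_switch(1, branches, 10))   # 20
print(reference_switch(-5, branches, 10))  # 11 (clamped to branch 0)
print(reference_switch(99, branches, 10))  # -10 (clamped to the last branch)
```

In real JAX code every branch is traced, so all branches must accept the same operands and return outputs with the same structure. This sketch does not enforce that.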
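Implemented `nnx.while_loop`, which mirrors `jax.lax.while_loop`. No reference structure change is allowed inside `nnx.while_loop` for NNX objects: the body function's input and output must have the same reference and pytree structure, and only values may change.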
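`jax.lax.while_loop(cond_fun, body_fun, init_val)` is documented as equivalent to a plain Python `while` loop over a carry value. It also requires `body_fun` to return a carry with the same structure it received. The sketch below extends that idea into a rough model of the NNX rule:

- Nested dicts stand in for NNX objects.
- The hypothetical helper `structure_of` records dicts, lists, keys and shared references while erasing leaf values.

`reference_while_loop` and `structure_of` are illustrative assumptions, not Flax APIs. JAX checks structure once while tracing, but this model re-checks after every iteration.

```python
from typing import Any, Callable, Optional


def structure_of(tree: Any, _seen: Optional[dict] = None) -> Any:
    """Hypothetical stand-in for NNX's reference structure.

    Records dicts/lists and their keys, erases leaf values, and marks a
    container reached a second time (a shared reference) by its visit index.
    """
    if _seen is None:
        _seen = {}
    if isinstance(tree, (dict, list)):
        if id(tree) in _seen:
            return ("ref", _seen[id(tree)])
        _seen[id(tree)] = len(_seen)
        if isinstance(tree, dict):
            return ("dict", tuple((k, structure_of(tree[k], _seen)) for k in sorted(tree)))
        return ("list", tuple(structure_of(v, _seen) for v in tree))
    return "leaf"


def reference_while_loop(
    cond_fun: Callable[[Any], bool],
    body_fun: Callable[[Any], Any],
    init_val: Any,
) -> Any:
    """Python model of jax.lax.while_loop plus a structure-invariance check."""
    val = init_val
    expected = structure_of(val)
    while cond_fun(val):
        val = body_fun(val)
        # Values may change between iterations; the structure may not.
        if structure_of(val) != expected:
            raise ValueError("body_fun changed the structure of the loop carry")
    return val


def cond_fun(s):
    return s["step"] < 3


def body_fun(s):
    # Only leaf values change, so the structure check passes.
    return {"step": s["step"] + 1, "params": {"w": s["params"]["w"] * 2}}


def bad_body(s):
    # Adds a new key "b": a structural change, so it is rejected.
    return {"step": s["step"] + 1, "params": {"w": 1.0, "b": 0.0}}


state = {"step": 0, "params": {"w": 1.0}}
print(reference_while_loop(cond_fun, body_fun, state))  # {'step': 3, 'params': {'w': 8.0}}

try:
    reference_while_loop(cond_fun, bad_body, state)
except ValueError as err:
    print("rejected:", err)

# Two keys sharing one dict: returning separate copies breaks the sharing.
shared = {"w": 1.0}
tied = {"step": 0, "a": shared, "b": shared}


def untie(s):
    return {"step": s["step"] + 1, "a": {"w": 1.0}, "b": {"w": 1.0}}


try:
    reference_while_loop(cond_fun, untie, tied)
except ValueError as err:
    print("rejected:", err)
```

Shared references are part of the modeled structure because NNX tracks object identity. In this model, a body that replaces one shared sub-object with two separate copies is rejected even though no key was added or removed.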