Adds `nnx.pure`, which behaves similarly to `nnx.split` but returns a `Pure` pytree that can call any method of the original node while returning the resulting updates explicitly. `Pure.stateful` can be used to recover the node.
from flax import nnx
import jax
import jax.numpy as jnp
class StatefulLinear(nnx.Module):
  """Affine layer ``x @ w + b`` that also tracks how often it was applied.

  Attributes:
    w: ``nnx.Param`` of shape ``(din, dout)`` sampled with
      ``jax.random.uniform`` from a key drawn out of ``rngs``.
    b: ``nnx.Param`` of zeros with shape ``(dout,)``.
    count: ``nnx.Variable`` holding a ``uint32`` scalar, starting at 0 and
      bumped once per ``__call__``.
  """

  def __init__(self, din, dout, rngs):
    # Draw the key first so the RNG stream is consumed exactly once, before
    # any parameter is created.
    weight_key = rngs()
    self.w = nnx.Param(jax.random.uniform(weight_key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))

  def increment(self):
    """Add one to the stored call counter."""
    self.count.value = self.count.value + 1

  def __call__(self, x):
    """Bump the counter, then return ``x @ w + b``."""
    self.increment()
    projected = x @ self.w
    return projected + self.b
# Build the stateful layer, then wrap it with nnx.pure so its state can be
# passed through jax.jit as an explicit pytree.
linear = StatefulLinear(din=3, dout=2, rngs=nnx.Rngs(0))
pure_linear = nnx.pure(linear)
@jax.jit
def forward(x, pure_linear):
  """Apply the pure layer to ``x`` inside a jitted function.

  Calling the ``Pure`` object yields a pair: the layer output and a new
  ``Pure`` carrying the updated state. Both are returned so the caller can
  feed the updated state into the next call.
  """
  output, updated_linear = pure_linear(x)
  return output, updated_linear
x = jnp.ones((1, 3))
# Each jitted step returns the updated state, which is threaded into the
# next step.
y, pure_linear = forward(x, pure_linear)
y, pure_linear = forward(x, pure_linear)
# Turn the Pure pytree back into a regular stateful module; the counter
# should reflect both calls.
linear = pure_linear.stateful()
calls_seen = linear.count.value
assert calls_seen == 2
What does this PR do?
Adds `nnx.pure`, which behaves similarly to `nnx.split` but returns a `Pure` pytree that can call any method of the original node while returning the resulting updates explicitly. `Pure.stateful` can be used to recover the node.