google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
5.78k stars 610 forks source link

[nnx] pure API #4004

Open cgarciae opened 2 weeks ago

cgarciae commented 2 weeks ago

What does this PR do?

Adds nnx.pure, which behaves similarly to nnx.split but returns a Pure pytree that can call any method of the original node, returning the updates explicitly alongside the method's output. Pure.stateful can be used to recover the node.

from flax import nnx
import jax
import jax.numpy as jnp

class StatefulLinear(nnx.Module):
  """Linear layer that also counts how many times it has been called."""

  def __init__(self, din, dout, rngs):
    """Create the weight, the bias and the call counter.

    Args:
      din: first dimension of the weight ``w``.
      dout: second dimension of ``w`` and length of the bias ``b``.
      rngs: callable; ``rngs()`` supplies the key used to draw ``w``.
    """
    weight_shape = (din, dout)
    bias_shape = (dout,)
    self.w = nnx.Param(jax.random.uniform(rngs(), weight_shape))
    self.b = nnx.Param(jnp.zeros(bias_shape))
    # The counter is held in a plain Variable rather than a Param.
    self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))

  def increment(self):
    """Add one to the call counter."""
    self.count.value += 1

  def __call__(self, x):
    """Bump the call counter, then return ``x @ w + b``."""
    self.increment()
    projected = x @ self.w
    return projected + self.b

# Build the stateful module, then wrap it with nnx.pure so that it can be
# passed to the jitted function below as an ordinary argument.
linear = StatefulLinear(din=3, dout=2, rngs=nnx.Rngs(0))
pure_linear = nnx.pure(linear)

@jax.jit
def forward(x, pure_linear):
  """Call ``pure_linear`` on ``x`` under ``jax.jit``.

  Args:
    x: input forwarded to ``pure_linear``.
    pure_linear: callable whose result unpacks into two values.

  Returns:
    The pair ``(output, updated)`` unpacked from ``pure_linear(x)``.
  """
  output, updated = pure_linear(x)
  return output, updated

x = jnp.ones((1, 3))

# Run the jitted forward pass twice, feeding the returned pure_linear back
# in each time so the updates from one call are visible to the next.
for _ in range(2):
  y, pure_linear = forward(x, pure_linear)

# Recover a stateful module; its counter reflects both calls.
linear = pure_linear.stateful()
assert linear.count.value == 2