google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

CustomVJP example with NNX #4265

Closed · rodrigodzf closed this issue 2 days ago

rodrigodzf commented 3 days ago

I wanted to know how to properly use custom_vjp with NNX. I could only find the docstring for the Linen version, but none for the NNX one.

cgarciae commented 2 days ago

Here's an example from the docs:

import jax
import jax.numpy as jnp
from flax import nnx

# A Module holding two scalar parameters.
class Foo(nnx.Module):
  def __init__(self, x, y):
    self.x = nnx.Param(x)
    self.y = nnx.Param(y)

# Primal function: f(x, y) = sin(x) * y.
@nnx.custom_vjp
def f(m: Foo):
  return jnp.sin(m.x) * m.y

# Forward pass: returns the primal output plus the residuals the backward
# pass will need (here cos(x), sin(x), and the module itself).
def f_fwd(m: Foo):
  return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)

# Backward pass: takes the residuals and the cotangents, and returns the
# gradient with respect to the Module's state as an nnx.State.
def f_bwd(res, g):
  inputs_g, out_g = g
  cos_x, sin_x, m = res
  tangent_m = nnx.State(dict(x=cos_x * out_g * m.y, y=sin_x * out_g))
  return (tangent_m,)

f.defvjp(f_fwd, f_bwd)

m = Foo(x=jnp.array(1.), y=jnp.array(2.))
grads = nnx.grad(f)(m)

# Inspecting the gradient structure shows one entry per Param:
jax.tree.map(jnp.shape, grads)
# State({
#   'x': VariableState(
#     type=Param,
#     value=()
#   ),
#   'y': VariableState(
#     type=Param,
#     value=()
#   )
# })
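
Since f(x, y) = sin(x) * y, the analytic gradients are cos(x) * y and sin(x), so you can sanity-check the custom backward pass against them. A minimal sketch (not part of the docs example), assuming the returned State can be indexed by parameter name as in the output above:

# Illustrative sanity check: the custom VJP should reproduce the
# analytic derivatives of f(x, y) = sin(x) * y at x=1, y=2.
expected_x = jnp.cos(1.0) * 2.0  # d/dx sin(x) * y = cos(x) * y
expected_y = jnp.sin(1.0)        # d/dy sin(x) * y = sin(x)
assert jnp.allclose(grads['x'].value, expected_x)
assert jnp.allclose(grads['y'].value, expected_y)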
rodrigodzf commented 2 days ago

Excellent! Thank you!