Closed rodrigodzf closed 2 days ago
Here's an example from the docs:
import jax
import jax.numpy as jnp
from flax import nnx
class Foo(nnx.Module):
    """Example module that stores two values, ``x`` and ``y``, as parameters."""

    def __init__(self, x, y):
        """Wrap ``x`` and ``y`` in ``nnx.Param`` and store them as attributes."""
        self.x = nnx.Param(x)
        self.y = nnx.Param(y)
@nnx.custom_vjp
def f(m: Foo):
    """Return ``sin(m.x) * m.y``.

    Decorated with ``nnx.custom_vjp`` so that forward/backward rules can be
    registered on it through ``f.defvjp``.
    """
    sin_x = jnp.sin(m.x)
    scale = m.y
    return sin_x * scale
def f_fwd(m: Foo):
    """Forward rule for ``f``.

    Returns a pair ``(output, residuals)`` where ``output`` is ``f(m)`` and
    ``residuals`` is the tuple ``(cos(m.x), sin(m.x), m)``.
    """
    primal_out = f(m)
    residuals = (jnp.cos(m.x), jnp.sin(m.x), m)
    return primal_out, residuals
def f_bwd(res, g):
    """Backward rule for ``f``.

    Args:
        res: The 3-tuple ``(cos_x, sin_x, m)``.
        g: A 2-tuple; only its second element (the output cotangent) is used.
            NOTE(review): the first element is presumably the gradient of
            the input-state updates -- confirm against the NNX docs.

    Returns:
        A 1-tuple holding an ``nnx.State`` with entries ``'x'`` and ``'y'``.
    """
    _inputs_g, out_g = g  # first element is not needed for this rule
    cos_x, sin_x, m = res
    # d/dx [sin(x) * y] = cos(x) * y   and   d/dy [sin(x) * y] = sin(x)
    grad_x = cos_x * out_g * m.y
    grad_y = sin_x * out_g
    return (nnx.State({'x': grad_x, 'y': grad_y}),)
# Register the forward and backward rules on f.
f.defvjp(f_fwd, f_bwd)

# Build a module with scalar parameters and differentiate f with respect to it.
m = Foo(x=jnp.array(1.0), y=jnp.array(2.0))
grad_fn = nnx.grad(f)
grads = grad_fn(m)

# Map every gradient leaf to its shape.
jax.tree.map(jnp.shape, grads)
State({
'x': VariableState(
type=Param,
value=()
),
'y': VariableState(
type=Param,
value=()
)
})
Excellent! Thank you!
I wanted to know how to properly use
custom_vjp
with NNX. I could only find the docstring for the Linen
version, but no docstring for the NNX one.