google/flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

How to use nnx.custom_vjp with non-class arguments? Example needed #4301

Closed: hrbigelow closed this issue 2 weeks ago

hrbigelow commented 1 month ago

Hi @cgarciae

I wondered if it's possible to apply nnx.custom_vjp to a function like this:

@nnx.custom_vjp
def linear(m: MyLinear, x: jax.Array) -> jax.Array:
    y = x @ m.kernel + m.bias
    return y
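For context, MyLinear is a small module along these lines (a sketch, not copied verbatim from the notebook):

import jax
import jax.numpy as jnp
from flax import nnx

class MyLinear(nnx.Module):
    # A minimal linear module holding kernel and bias parameters.
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.kernel = nnx.Param(jax.random.normal(rngs.params(), (din, dout)))
        self.bias = nnx.Param(jnp.zeros((dout,)))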

But I'm not sure what linear_bwd should return. I tried this:

def linear_fwd(m: nnx.Linear, x: jax.Array):
    return linear(m, x), (m, x)

def linear_bwd(res, g):
    m, x = res
    # nnx.custom_vjp appears to pass g as a pair: cotangents for the
    # input-state updates and for the primal output.
    inputs_g, outputs_g = g
    kernel_grad = outputs_g[None, :] * x[:, None]  # outer product: d(loss)/d(kernel)
    bias_grad = outputs_g
    x_grad = m.kernel @ outputs_g
    assert x_grad.shape == x.shape, 'Shape mismatch for x'
    assert m.kernel.value.shape == kernel_grad.shape, 'Shape mismatch for kernel'
    assert m.bias.value.shape == bias_grad.shape, 'Shape mismatch for bias'
    m_g = nnx.State(dict(kernel=kernel_grad, bias=bias_grad))
    x_g = nnx.State((x_grad,))
    # x_g = nnx.State(dict(x=x_grad))  # also tried this
    return (m_g, x_g)
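For completeness, I register the pair the same way as with jax.custom_vjp and exercise it through nnx.grad (sizes here are hypothetical):

linear.defvjp(linear_fwd, linear_bwd)

# Take gradients of a scalar loss with respect to the module.
m = MyLinear(4, 3, rngs=nnx.Rngs(0))
x = jnp.ones((4,))
grads = nnx.grad(lambda m: linear(m, x).sum())(m)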

The notebook is here

Ultimately I want an nnx.Module whose __call__ method causes the module's parameters to be updated during the backward pass, roughly along the lines sketched below. Any guidance would be greatly appreciated!
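Purely as a sketch of the shape I have in mind (hypothetical, reusing linear and the imports from above; whether a backward rule may update the module like this is exactly what I'm unsure about):

class SelfUpdatingLinear(nnx.Module):
    # Hypothetical: __call__ routes through the custom_vjp function, whose
    # backward rule would also apply a local parameter update.
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.kernel = nnx.Param(jax.random.normal(rngs.params(), (din, dout)))
        self.bias = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        return linear(self, x)  # backward pass would update kernel and bias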

Best,

Henry

cgarciae commented 1 month ago

Hey @hrbigelow, seems we have a bug. Thanks for reporting this!

The tangents for Modules should be State objects; however, you are running into something else. custom_vjp is very new and one of the trickiest transforms, so I'd expect a few hiccups along the way.
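For concreteness, the intended contract is roughly this (a sketch; the exact leaf types may differ once the fix lands):

import jax.numpy as jnp
from flax import nnx

din, dout = 4, 3  # example sizes
# The tangent for a Module argument should be a State mirroring the
# module's differentiable parameters.
tangent_m = nnx.State(dict(
    kernel=jnp.zeros((din, dout)),  # same shape as m.kernel
    bias=jnp.zeros((dout,)),        # same shape as m.bias
))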

hrbigelow commented 1 month ago

Hi @cgarciae,

Thanks for looking into this. I can see this would be very tricky. If it works, it should make it possible to express different learning algorithms with minimal code. I look forward to using it.

By the way, for context: I opened a JAX discussion before I was aware of Flax NNX (I had been using Haiku).

hrbigelow commented 3 weeks ago

@cgarciae feel free to close this if you like. I've been using the fix and it works great.