Closed hrbigelow closed 2 weeks ago
Hey @hrbigelow, seems we have a bug. Thanks for reporting this!
The tangents for Modules should be State objects; however, you are running into something else. custom_vjp is very new and one of the trickiest transforms, so I'd expect a few hiccups along the way.
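For reference, a minimal sketch of the underlying contract using plain jax.custom_vjp (not the nnx wrapper, so no Module/State handling): the backward function must return one cotangent per primal input, in the same order, which is what a linear_bwd would need to satisfy. The function names here are illustrative, not taken from the notebook.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def linear(W, x):
    return W @ x

def linear_fwd(W, x):
    # Save the residuals the backward pass will need.
    return W @ x, (W, x)

def linear_bwd(res, g):
    W, x = res
    dW = jnp.outer(g, x)   # cotangent w.r.t. W
    dx = W.T @ g           # cotangent w.r.t. x
    # Must return one cotangent per primal input, in order: (W, x).
    return dW, dx

linear.defvjp(linear_fwd, linear_bwd)

W = jnp.ones((2, 3))
x = jnp.arange(3.0)
y, vjp_fn = jax.vjp(linear, W, x)
dW, dx = vjp_fn(jnp.ones(2))
```

With nnx.custom_vjp the same structure applies, except that the cotangent for a Module argument is expressed as a State object rather than a plain array pytree.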
Hi @cgarciae,
Thanks for looking into this. I can see this would be very tricky. I think it will be extremely useful if it opens up the opportunity to express different learning algorithms with minimal code. I look forward to using it.
By the way, for context, I did open a JAX discussion before I was aware of Flax NNX (I had been using Haiku).
@cgarciae feel free to close this if you like. I've been using the fix and it works great.
Hi @cgarciae
I wondered if it's possible to apply an nnx.custom_vjp to a function like:

But I'm not sure what linear_bwd should return. I tried this:

The notebook is here
Ultimately I want to have an nnx.Module whose __call__ method is such that the module parameters are updated during the backward pass. Any guidance would be greatly appreciated!

Best,
Henry