cgarciae closed this 3 weeks ago.
Hi @cgarciae - I'm excited to try this fix!
Just wanted to mention a slight discrepancy between the bwd signatures expected by nnx and jax custom_vjp:
nnx bwd: res, (ins_g, outs_g) -> tangent
jax bwd: res, outs_g -> tangent
The docs here seem to be showing an example of jax.custom_vjp.
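For comparison, here is a minimal jax.custom_vjp sketch (the standard JAX API, modeled on the example in the JAX docs) illustrating the jax-style bwd: it receives the residuals and the output cotangent `g`, and returns a tuple with one cotangent per primal input. Per the comment above, the nnx bwd instead receives its second argument as a pair `(ins_g, outs_g)`. The function `f` and the input values are illustrative only.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y


def f_fwd(x, y):
    # Forward pass: return the primal output plus residuals for bwd.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)


def f_bwd(res, g):
    # jax-style bwd: (res, outs_g) -> one cotangent per primal input.
    # (The nnx-style bwd would instead receive g as (ins_g, outs_g).)
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)


f.defvjp(f_fwd, f_bwd)

# Gradients w.r.t. both inputs: (cos(x) * y, sin(x)).
grads = jax.grad(f, argnums=(0, 1))(1.0, 2.0)
```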
@hrbigelow this is correct! I expanded the docs a bit to reflect this.
BTW, thanks for waiting. I've been trying to get this transform right, but it's taken a few rewrites to solve some edge cases.
What does this PR do?
FwdFn is called outside of an update context. When this happens, FwdFn behaves as a pure function.
Solves #4265.