google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] fix custom_vjp #4306

Closed: cgarciae closed this 3 weeks ago

cgarciae commented 1 month ago

What does this PR do?

Solves #4265.

hrbigelow commented 1 month ago

Hi @cgarciae - I'm excited to try this fix!

Just wanted to mention a slight discrepancy between the bwd signatures expected by nnx.custom_vjp and jax.custom_vjp:

nnx bwd: res, (ins_g, outs_g) -> tangent
jax bwd: res, outs_g -> tangent

The docs here seem to be showing an example of jax.custom_vjp rather than nnx.custom_vjp.
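For reference, here is a minimal sketch of the plain jax.custom_vjp convention: bwd receives only the residuals and the output cotangent, and returns a tuple with one cotangent per primal input. The function scaled_sin and its fwd/bwd helpers are hypothetical examples, not code from this PR or the Flax docs. Per the comment, the nnx variant instead receives a pair (ins_g, outs_g) as its second argument; that version is not shown here.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: scale * sin(x).
@jax.custom_vjp
def scaled_sin(x, scale):
    return scale * jnp.sin(x)

def scaled_sin_fwd(x, scale):
    # Return the primal output plus residuals saved for the backward pass.
    return scaled_sin(x, scale), (x, scale)

def scaled_sin_bwd(res, g):
    # jax.custom_vjp bwd signature: (residuals, output cotangent) ->
    # tuple of cotangents, one per primal input of scaled_sin.
    x, scale = res
    return (g * scale * jnp.cos(x), g * jnp.sin(x))

scaled_sin.defvjp(scaled_sin_fwd, scaled_sin_bwd)

# Gradients w.r.t. both inputs at x=1.0, scale=2.0.
dx, dscale = jax.grad(scaled_sin, argnums=(0, 1))(1.0, 2.0)
```

Comparing the result with the analytic gradient (2 cos 1 for x and sin 1 for scale) is a quick way to confirm that bwd returns its cotangents in the same order as the primal arguments.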

cgarciae commented 4 weeks ago

@hrbigelow this is correct! I expanded the docs a bit to reflect this.

BTW: thanks for waiting. I've been trying to get this transform right, but it's taken a few rewrites to handle some edge cases.