google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] add jit donation test #4022

Open: cgarciae opened this pull request 1 week ago


What does this PR do?

Adds a test that uses nnx.jit with donate_argnums to showcase partial initialization.
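Below is a minimal sketch of the pattern. It assumes "partial initialization" means creating only the missing parameters of a model whose other parameters already exist, such as ones loaded from a checkpoint. It uses plain jax.jit instead of nnx.jit, on the assumption that nnx.jit's donate_argnums has the same meaning. The parameter tree, the hypothetical init_head function, and the shapes are illustrative only and are not taken from the PR's test. Setting donate_argnums=0 passes ownership of the existing buffers to the compiled function, so XLA may reuse them for the output instead of copying them.

```python
import jax
import jax.numpy as jnp


def make_partial_params():
    # Hypothetical parameter tree: "encoder" already exists (e.g. restored
    # from a checkpoint), while the "head" layer has not been created yet.
    return {"encoder": {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))}}


def init_head(params, key):
    # Hypothetical helper: keep the existing leaves and add only the missing
    # "head" layer. Because the encoder arrays are returned unchanged, XLA
    # can alias the donated input buffers to the outputs.
    head = {"w": jax.random.normal(key, (4, 2)), "b": jnp.zeros((2,))}
    return {**params, "head": head}


# donate_argnums=0 donates `params`. After the call, the caller should not
# read the original arrays again, since they may have been invalidated.
init_head_jit = jax.jit(init_head, donate_argnums=0)

params = make_partial_params()
params = init_head_jit(params, jax.random.PRNGKey(0))
```

Rebinding params to the result is the usual way to handle donation, because it drops the only reference to the donated inputs. On backends that do not support donation, JAX may warn that the donated buffers were not usable. The result is still correct in that case.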