Fixing all losses to return all of their data as non-auxiliary data (pytree children) during flattening, to avoid tracer leaks when the weight is dynamic.
Simplifying the LossFunction flattening/unflattening to use the parameter_dependents and parameter_independents.
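
As an illustration of the flattening rule described above, here is a minimal sketch of a pytree-registered loss whose fields all travel as children rather than as auxiliary data. `WeightedLoss`, `target`, `weight` and `scaled` are hypothetical names chosen for this example and are not taken from the actual LossFunction code; only the `jax.tree_util` calls are real JAX API.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class WeightedLoss:
    """Hypothetical stand-in for a loss with a dynamic weight."""

    def __init__(self, target, weight):
        self.target = target
        self.weight = weight

    def __call__(self, prediction):
        return self.weight * jnp.sum((prediction - self.target) ** 2)

    def tree_flatten(self):
        # Every array goes into the children (non-auxiliary data), so JAX
        # is allowed to replace it with a tracer. Nothing is kept in
        # aux_data, which JAX treats as static metadata.
        children = (self.target, self.weight)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def scaled(loss, factor):
    # Inside jit, loss.weight is a tracer; returning a new loss is safe
    # because the weight is flattened as a child, not as aux_data.
    return WeightedLoss(loss.target, loss.weight * factor)


loss = WeightedLoss(jnp.array([1.0, 2.0]), jnp.array(0.5))
new_loss = scaled(loss, 2.0)
value = new_loss(jnp.array([2.0, 4.0]))
```

If the weight were placed in `aux_data` instead, a traced weight would end up stored in what JAX treats as static metadata, which is the kind of situation that can produce a leaked tracer.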
Modifying the tracer to work around a recently introduced bug in JAX that makes the current code throw a tracer leak error. This is done by never returning an actual LossFunction object, only its inputs, and reconstructing the object afterwards, outside of any calls to jax.vjp/jax.jvp.
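
The workaround pattern can be sketched roughly as follows, assuming a simplified setting: the function passed to `jax.vjp` returns only the arrays needed to build the loss, and the object itself is constructed afterwards from the concrete outputs. `QuadraticLoss` and `loss_inputs` are hypothetical names for this example and do not reflect the real tracer code.

```python
import jax
import jax.numpy as jnp


class QuadraticLoss:
    """Hypothetical plain Python loss object (not registered as a pytree)."""

    def __init__(self, target, weight):
        self.target = target
        self.weight = weight

    def __call__(self, prediction):
        return self.weight * jnp.sum((prediction - self.target) ** 2)


def loss_inputs(params):
    # Return only arrays, never the loss object itself, so no object
    # holding tracers escapes the differentiated function.
    target = 2.0 * params
    weight = jnp.sum(params)
    return target, weight


params = jnp.array([1.0, 2.0])

# Differentiate the function that produces the inputs of the loss.
(target, weight), vjp_fn = jax.vjp(loss_inputs, params)

# Rebuild the loss outside of jax.vjp, from concrete arrays.
rebuilt = QuadraticLoss(target, weight)

# Pull a cotangent back through the inputs; its structure must match
# the (target, weight) output of loss_inputs.
(cotangent,) = vjp_fn((jnp.ones_like(target), jnp.array(1.0)))
```

The same structure applies to `jax.jvp`: the differentiated function returns the inputs, and the loss object is assembled from the primal outputs once the call has returned.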