The bit that matters here is that f_eval_info.jac in
AbstractGaussNewton.step now throws away the static (non-array) parts
of its PyTree, and instead uses the equivalent static (non-array) parts
of state.f_info.jac, i.e. those computed in
AbstractGaussNewton.init.
Now at a logical level this shouldn't matter at all: the static pieces
should be the same in both cases, as they're just the output of
_make_f_info with similarly-structured inputs.
However, _make_f_info calls lx.FunctionLinearOperator which calls
eqx.filter_closure_convert which calls jax.make_jaxpr which returns
a jaxpr... and so between the two calls to _make_f_info, we actually
end up with two jaxprs. Both encode the same program, but are two
different Python objects. Now jaxprs have __eq__ defined according to
identity, so these two (functionally identical) jaxprs do not compare
as equal.
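To make the identity-based comparison concrete, here is a minimal pure-Python sketch. It does not use JAX: `ToyJaxpr` and `trace` are hypothetical stand-ins for a real jaxpr and for tracing. Like any Python class that defines no `__eq__`, `ToyJaxpr` falls back to comparison by identity.

```python
class ToyJaxpr:
    """Hypothetical stand-in for a jaxpr: holds a program, defines no __eq__."""

    def __init__(self, program):
        self.program = program


def trace(fn_source):
    # Each "trace" builds a brand-new object, even for identical input.
    return ToyJaxpr(program=fn_source)


first = trace("lambda x: x + 1")
second = trace("lambda x: x + 1")

same_program = first.program == second.program  # the encoded programs match
same_object = first == second  # default __eq__ compares by identity

print(same_program, same_object)
```

The two objects encode the same program, yet anything that keys on `==` will treat them as different.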
Previously we worked around this inside _iterate.py: we carefully
removed or wrapped any jaxprs before anything that would try to compare
them for equality. This was a bit ugly, but it worked.
However, it turns out that this still left a problem when manually
stepping an Optimistix solver! (In a way akin to an Optax solver:
something like
    @eqx.filter_jit
    def make_step(...):
        ... = solver.step(...)

    for ... in ...:  # Python level for-loop
        ... = make_step(...)
)
then in fact on every iteration of the Python loop, we would end up
recompiling, as we always get a new jaxpr at
    state      # state for the Gauss-Newton solver
      .f_info  # as returned by _make_f_info
      .jac     # the FunctionLinearOperator
      .fn      # the closure-converted function
      .jaxpr   # the jaxpr from the closure conversion
!
Now one fix is simply to demand that manually stepping a solver
requires hackery similar to what we had in _iterate.py. But maybe enough
is enough, and we should try doing something better instead: that is,
we do what this PR does, and just preserve the same jaxpr all the way
through.
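The recompilation problem and the fix can be sketched with a toy model in plain Python. Everything here is hypothetical and for illustration only: `ToyJaxpr`, `ToyState` and `toy_jit_step` are not JAX, Equinox or Optimistix APIs, and the "JIT cache" is just a dict keyed on the static part of the state.

```python
class ToyJaxpr:
    """Hypothetical stand-in for a jaxpr; compares (and hashes) by identity."""

    def __init__(self, program):
        self.program = program


class ToyState:
    """Hypothetical solver state: array data plus a static (non-array) part."""

    def __init__(self, value, static):
        self.value = value
        self.static = static


compile_count = 0
_cache = {}


def toy_jit_step(step_fn, state):
    """Mimic a JIT cache keyed on the static part of the input."""
    global compile_count
    key = (step_fn, state.static)
    if key not in _cache:
        compile_count += 1  # cache miss: "recompile"
        _cache[key] = step_fn
    return _cache[key](state)


def step_fresh_static(state):
    # Old behaviour: every step re-traces and returns a new static object.
    return ToyState(state.value + 1, ToyJaxpr("residual program"))


def step_reuse_static(state):
    # Behaviour after the fix: re-trace, then discard the new static part
    # and carry forward the one already stored in the state.
    _discarded = ToyJaxpr("residual program")
    return ToyState(state.value + 1, state.static)


def run(step_fn, num_steps):
    """Step manually in a Python loop and return the number of compilations."""
    global compile_count
    compile_count = 0
    _cache.clear()
    state = ToyState(0, ToyJaxpr("residual program"))
    for _ in range(num_steps):  # Python-level for-loop
        state = toy_jit_step(step_fn, state)
    return compile_count


compiles_before = run(step_fresh_static, num_steps=5)
compiles_after = run(step_reuse_static, num_steps=5)
print(compiles_before, compiles_after)
```

With five manual steps, the first variant misses the cache on every step, while the second compiles once and then reuses the cached entry.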
For bonus points, this means that we can now remove our special jaxpr
handling from _iterate.py (and from filter_cond, which also needed
this for the same reason).
Finally, you might be wondering: why do we need to trace two equivalent
jaxprs at all? This seems inefficient -- can we arrange to trace it
just once? The answer is "probably, but not in this PR". This seems to
require that (a) Lineax offer a way to turn off closure conversion
(done in https://github.com/google/lineax/pull/71), but (b) even when
using this, we still seem to hit a similar issue in JAX, around the
requirement that the primal and tangent results from jax.custom_jvp
match. So for now
this is just something to try and tackle later -- once we do, we'll get
slightly better compile times.
This is quite an important fix!