packquickly opened this issue 1 month ago
Hi Jason,
haha, I had the same thing on my To-Do list and just wrote an MWE. I came across this a few weeks ago; it would sure be handy to be able to vmap over initial conditions and then check the stats, for example in a multi-start scenario. (Not super urgent for me.) My workaround so far has been to only return `solution.value`, with an optional return of `solution.stats`, and to ignore the rest.
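Roughly like this - a minimal sketch with a toy residual rather than my actual model, so all names here are illustrative:

```python
import jax
import jax.numpy as jnp
import optimistix as optx


def residuals(y, args):
    del args
    return y - jnp.arange(3.0)  # toy residual, minimum at [0, 1, 2]


solver = optx.LevenbergMarquardt(rtol=1e-6, atol=1e-9)


def solve_value_only(y0):
    sol = optx.least_squares(residuals, solver, y0, args=None)
    # Return arrays only: the state (and the jaxpr it contains) never
    # crosses the vmap boundary.
    return sol.value


y0s = jnp.zeros((5, 3))  # batch of initial guesses
values = jax.vmap(solve_value_only)(y0s)
```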
Why does the solution object need to include the whole jaxpr to begin with? It sure seems useful to be able to inspect during debugging, but otherwise I would have no need to look at it - maybe it could be made optional? (I optimize over the parameters of ODE models, so my jaxprs are always super long.)
In this case, the jaxpr contains the following useful message:

```
nonbatchable[
  allow_constant_across_batch=True
  msg=Nonconstant batch. `equinox.internal.while_loop` has received a batch of values that
  were expected to be constant. This is probably an internal error in the library you are using.
]
```
However, the code does not fail with that message. It fails twice - once with

```
.../site-packages/jax/_src/interpreters/batching.py:1107, in matchaxis(axis_name, sz, src, dst, x, sum_match)
   1105   _ = core.get_aval(x)
   1106 except TypeError as e:
-> 1107   raise TypeError(f"Output from batched function {x!r} with type "
   1108                   f"{type(x)} is not a valid JAX type") from e
```

and once with

```
.../site-packages/jax/_src/core.py:1455, in concrete_aval(x)
   1454   return concrete_aval(x.__jax_array__())
-> 1455 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1456                 "type")
```
For what it's worth, here is the MWE - even if it is now probably redundant :) I added `BFGS` as a solver.

Edit: print the number of lines in the jaxpr (1.2k lines, more than 100k characters). Edit Nr. 2: add forward and backward options. Edit Nr. 3: condense the MWE. (Also removed the line count for the jaxpr.)
```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array
import equinox as eqx
import diffrax as dfx
import optimistix as optx
import pytest


class ToyModel(eqx.Module):
    _term: dfx.ODETerm

    def __init__(self):
        def dydt(t, y, k):  # Monoexponential decay
            return -k * y

        self._term = dfx.ODETerm(dydt)

    def __call__(self, param):
        t0 = 0.0
        t1 = 10.0
        dt0 = 0.01
        y0 = jnp.array([10.0])
        sol = dfx.diffeqsolve(
            self._term,
            dfx.Tsit5(),
            t0, t1, dt0, y0, args=param,
            saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1)),
            adjoint=dfx.DirectAdjoint(),  # Supports both fwd and bwd autodiff
        )
        return sol.ys


def estimate_parameters(initial_guess, model, data, solver, solver_options: dict = dict(jac="fwd")):
    """Function that estimates the parameters."""

    def residuals(param, args):
        model, data = args
        fit = model(param)
        res = data - fit
        return res

    sol = optx.least_squares(
        residuals,
        solver,
        initial_guess,
        args=(model, data),
        options=solver_options,
    )
    return sol


model = ToyModel()
k = jnp.array([0.5])  # True value
ode_solution = model(k)
k0s = jnp.transpose(jnp.array([jnp.arange(0.0, 0.45, 0.05)]))

bfgs = optx.BFGS(atol=1e-09, rtol=1e-06)
lm = optx.LevenbergMarquardt(atol=1e-09, rtol=1e-06)
gn = optx.GaussNewton(atol=1e-09, rtol=1e-06)

# Add a vmap on top
vmapped_fwd_solve = jax.vmap(estimate_parameters, in_axes=(0, None, None, None))
vmapped_bwd_solve = jax.vmap(
    jtu.Partial(estimate_parameters, solver_options=dict(jac="bwd")),
    in_axes=(0, None, None, None),
)

for solver in [bfgs, lm, gn]:
    sol_bwd = vmapped_bwd_solve(k0s, model, ode_solution, solver)
    assert isinstance(sol_bwd, optx.Solution)
    if solver == bfgs:
        sol_fwd = vmapped_fwd_solve(k0s, model, ode_solution, solver)
        assert isinstance(sol_fwd, optx.Solution)
    else:
        with pytest.raises(TypeError):
            sol_fwd = vmapped_fwd_solve(k0s, model, ode_solution, solver)

print("Checks passed, expected errors raised. This is what happens: ")
# Repeat one of the calls that raises the error
vmapped_fwd_solve(k0s, model, ode_solution, lm)  # This fails
```
I realise this is a bit long for an MWE - only the last line fails.
I dug a little into `_make_f_info` from the `gauss_newton` module, and it seems to me that the `lambda` in `jax.linearize(...)` is redundant. If you print the returned linearized function `lin_fn`, you can see that it has a jaxpr - which the output of a plain `jax.linearize(...)` does not.

What I do not understand is why there has to be an auxiliary output, but I worked around it with a wrapper for now:
```python
import pytest
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.core import Jaxpr
import equinox as eqx
from optimistix._solver.gauss_newton import _make_f_info


def _for_jacrev(_y):
    """Copied from: optimistix._solver.gauss_newton"""
    # Quoted for reference only; `fn` and `args` are free variables here and
    # this helper is never called below.
    f_eval, aux_eval = fn(_y, args)  # Why does this assume an auxiliary output?
    return f_eval, (f_eval, aux_eval)


def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x


def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)


def residuals(origin, args):
    del args

    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0) ** 2

    true = shifted_parabola(2)
    fit = shifted_parabola(origin)
    return true - fit


def aux_wrapper(origin, args):
    return residuals(origin, args), None


initial_guess = 1.0
args = ()

jac_bwd = jax.jacrev(residuals)(initial_guess, args)
jac_fwd = jax.jacfwd(residuals)(initial_guess, args)
assert tree_allclose(jac_bwd, jac_fwd)

with pytest.raises(ValueError):  # Only works with aux wrapper
    _make_f_info(residuals, initial_guess, args, set(), "bwd")

# Now with auxiliary output wrapper
(residual_jac_bwd, _) = _make_f_info(aux_wrapper, initial_guess, args, set(), "bwd")
assert tree_allclose(residual_jac_bwd.jac.pytree, jac_bwd)
(residual_jac_fwd, _) = _make_f_info(aux_wrapper, initial_guess, args, set(), "fwd")
assert isinstance(residual_jac_fwd.jac.fn.jaxpr, Jaxpr)

# The following snippet (jax.linearize...) is copied from _make_f_info (line 174)
with pytest.raises(ValueError):  # Again, must have aux
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess, has_aux=False)
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess)
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: residuals(_y, args), initial_guess)  # ValueError: residuals does not have aux
f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess, has_aux=True)
assert aux_eval is None
assert tree_allclose(f_eval, residuals(initial_guess, args))
print(lin_fn)  # lin_fn has a jaxpr

# Compare to jax.linearize without the lambda function
res, residuals_jvp = jax.linearize(residuals, *(initial_guess, args))
assert tree_allclose(res, residuals(initial_guess, args))
assert tree_allclose(jac_bwd, residuals_jvp(initial_guess, args))
res, residuals_jvp, aux = jax.linearize(aux_wrapper, *(initial_guess, args), has_aux=True)
assert tree_allclose(res, residuals(initial_guess, args))
assert tree_allclose(jac_bwd, residuals_jvp(initial_guess, args))
assert aux is None
```
By now you can skip most of my thought process above :) I believe it is a somewhat subtle issue involving the eval shapes passed to the linear operator in `_make_f_info` from `gauss_newton.py`. The `jvp`s computed using `jax.linearize` contain `jaxpr`s in both cases, but one evaluates to the correct Jacobian and one does not.

Using the MWE below, I get

```
TypeError: Expected PyTreeDef((*, ())), got PyTreeDef(((*, ()),))
```
**MWE**
```python
import pytest
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import lineax as lx
from optimistix._solver.gauss_newton import _make_f_info


def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x


def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)


def fn(y, args):  # Optimistix insists on aux, it seems: return an extra output
    del args

    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0) ** 2

    true = shifted_parabola(2)
    fit = shifted_parabola(y)
    return true - fit, 0  # Return zero as aux


y0 = 1.0
nothing = ()

# Compute the Jacobian two ways, ignore fn_eval, aux_eval
_, fn_j, _ = jax.linearize(lambda _y: fn(_y, nothing), y0, has_aux=True)  # Status quo
_, fn_j_no_lambda, _ = jax.linearize(fn, *(y0, nothing), has_aux=True)

# Now check the Jacobians
true_jac_eval, _ = jax.jacfwd(fn)(y0, nothing)  # Throw away aux_eval
with pytest.raises(TypeError):
    assert tree_allclose(fn_j(y0, nothing), true_jac_eval)  # lambda used: input is an unexpected pytree
assert tree_allclose(fn_j_no_lambda(y0, nothing), true_jac_eval)

# Check eval shapes
fn_eval_shape, aux_eval_shape = jax.eval_shape(fn, *(y0, nothing))  # Returns tuple that includes aux
fn_j_no_lambda_eval_shape = jax.eval_shape(fn_j_no_lambda, *(y0, nothing))
assert tree_allclose(fn_eval_shape, fn_j_no_lambda_eval_shape)
with pytest.raises(TypeError):
    jax.eval_shape(fn_j, *(y0, nothing))  # unexpected pytree, again

# Create lx.FunctionLinearOperator
lx.FunctionLinearOperator(fn_j_no_lambda, jax.eval_shape(lambda x: x, (y0, nothing)))
```
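A plausible reading of the `PyTreeDef` mismatch above is an arity difference: the `lambda` closes over `args`, so the linearized function expects a tangent for `_y` only, while the no-lambda version expects tangents for both `y` and `args`. A minimal sketch of that difference, independent of optimistix (all names here are illustrative):

```python
import jax


def f(y, args):
    del args
    return y**2


y0 = 1.0
args = ()

# Closing over `args`: the linearized function takes one tangent, for `_y` only.
_, lin_closed = jax.linearize(lambda _y: f(_y, args), y0)
print(lin_closed(1.0))  # 2.0

# Passing `args` as a primal: the linearized function takes two tangents.
_, lin_open = jax.linearize(f, y0, args)
print(lin_open(1.0, ()))  # 2.0

# Calling the closed-over version with two tangents produces a pytree
# structure error like the one quoted above.
# lin_closed(1.0, ())  # TypeError: unexpected pytree structure
```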
Related: I opened https://github.com/google/jax/issues/21581
Here is my latest iteration, still poking at the two lines in `_make_f_info` from the Gauss-Newton solver. I could show that the jaxpr is not causing the issue - at least not outside of `FunctionLinearOperator`: vmapping over a Jacobian function that contains a jaxpr and is an output of `jax.linearize` raises no errors.
```python
import pytest
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import lineax as lx


def _no_nan(x):
    """Copied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x


def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)


def fn(y):
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0) ** 2

    true = shifted_parabola(2.0)  # True value
    fit = shifted_parabola(y)
    return true - fit


def aux_wrapper(y):
    return fn(y), None


y0 = 1.0  # starting guess
y0s = jnp.arange(0.0, 4.0, 0.1)  # Many initial values

# Get Jacobians the simple way
_, jac_of_fn = jax.linearize(fn, y0)
_, jac_of_aux_wrapper, _ = jax.linearize(aux_wrapper, y0, has_aux=True)
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_fn(y0))
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_aux_wrapper(y0))
vmapped_jac_of_fn = jax.vmap(jac_of_fn)(y0s)  # Does not raise an error
vmapped_jac_of_aux_wrapper = jax.vmap(jac_of_aux_wrapper)(y0s)  # Does not raise an error

# Now include the lambda function in jax.linearize (status quo in optimistix)
_, jac_of_fn_with_lambda = jax.linearize(lambda _y: fn(_y), y0)
_, jac_of_aux_wrapper_with_lambda, _ = jax.linearize(lambda _y: aux_wrapper(_y), y0, has_aux=True)
vmapped_jac_of_fn_with_lambda = jax.vmap(jac_of_fn_with_lambda)(y0s)  # Does not raise an error
vmapped_jac_of_aux_wrapper_with_lambda = jax.vmap(jac_of_aux_wrapper_with_lambda)(y0s)  # Does not raise an error

# Context: using lambda functions produces a subtle difference in the pytrees,
# not legible when examining the pytreedef (as a human)
with pytest.raises(AssertionError):
    assert jtu.tree_structure(jac_of_fn) == jtu.tree_structure(jac_of_fn_with_lambda)
assert str(jtu.tree_structure(jac_of_fn)) == str(jtu.tree_structure(jac_of_fn_with_lambda))


# Create a lineax FunctionLinearOperator
def lin_fun(y):
    return 2 * y


lin_op = lx.FunctionLinearOperator(lin_fun, jax.eval_shape(lin_fun, y0))  # Confirm that it works in this case
with pytest.raises(ValueError):  # I don't understand why it does not work in these cases
    lx.FunctionLinearOperator(jac_of_fn, jax.eval_shape(jac_of_fn, y0))
    lx.FunctionLinearOperator(jac_of_fn_with_lambda, jax.eval_shape(jac_of_fn_with_lambda, y0))
```
A simple workaround: `dataclasses.replace()` the offending member of the output. The tricky thing is that I can't figure out what the offending member is. And it has now been shown that it is a deeper issue in `jax.linearize`, which produces pytrees with nonidentical structure even for identical input functions called with identical inputs.
I think I understand what's going on here. The output of `optx.least_squares` includes a jaxpr inside of `out.state`. This isn't an array-like object, so JAX doesn't understand how to handle it as an output of the `vmap`. Morally speaking, what's going on here is the same as `jax.vmap(lambda x: object())(...)`, in which again a non-array-like object is being returned.
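A minimal sketch of that diagnosis, independent of optimistix - returning any non-array object from a vmapped function raises the same error quoted above:

```python
import jax
import jax.numpy as jnp

# Returning a non-array-like object (here a plain `object()`) from a vmapped
# function fails with "... is not a valid JAX type", just as the jaxpr inside
# `out.state` does.
jax.vmap(lambda x: (x + 1, object()))(jnp.arange(3.0))
```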
The solution is pretty simple: use `eqx.filter_vmap` instead. This passes through all non-array-like objects unchanged. Indeed, the use case in this issue is the raison d'être of `eqx.filter_vmap`!
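Applied to the first MWE above, the fix would be a one-line change - a sketch, assuming the names from that MWE are in scope:

```python
import equinox as eqx

# eqx.filter_vmap batches the array leaves of the returned optx.Solution and
# passes the non-array leaves (such as the jaxpr inside `state`) through
# unchanged.
filtered_fwd_solve = eqx.filter_vmap(estimate_parameters, in_axes=(0, None, None, None))
sol_fwd = filtered_fwd_solve(k0s, model, ode_solution, lm)  # no TypeError
```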
Does this solve the issues everyone is facing?
Oh dear :D
It does solve my issue. I was actually in the process of replacing all `vmap`s with `filter_vmap`s, but there were still some around. Not anymore, though!
Vmapping across `y0` with any method using `AbstractGaussNewton` throws a TypeError (see the MWE). The reason for this looks to be that the state includes an `f_info` with a `FunctionLinearOperator` whose linearised function is a `Jaxpr`, which can't be batched over.