patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Can't vmap across input using Gauss Newton fwd #63

Open packquickly opened 1 month ago

packquickly commented 1 month ago

Vmapping across y0 with any solver that subclasses AbstractGaussNewton throws a TypeError. MWE:

import jax
import jax.numpy as jnp

import optimistix as optx

def rosenbrock(x, args):
    del args
    term1 = 10 * (x[1:] - x[:-1] ** 2)
    term2 = x - 1
    return term1, term2

inits = jnp.zeros((4, 10))
solve = lambda x: optx.least_squares(rosenbrock, optx.LevenbergMarquardt(1e-8, 1e-9), x)
out = jax.vmap(solve)(inits)  # throws error

The reason for this looks to be that the state includes an f_info with a FunctionLinearOperator whose linearised function contains a jaxpr, which can't be batched over.
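
The same failure mode can be reproduced in isolation; the sketch below is illustrative only and not optimistix-specific. vmap requires every output to be a valid JAX type, so returning anything else (a jaxpr, a plain Python object) raises the same TypeError.

import jax
import jax.numpy as jnp

def returns_non_array(x):
    # The second output is not a valid JAX type, mirroring the jaxpr held in the solver state.
    return x + 1, object()

# jax.vmap(returns_non_array)(jnp.arange(3.0))  # raises TypeError: ... is not a valid JAX type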

johannahaffner commented 1 month ago

Hi Jason,

haha, I had the same thing on my to-do list and just wrote an MWE. I came across this a few weeks ago; it would sure be handy to be able to vmap over initial conditions and then check the stats, for example in a multi-start scenario. (Not super urgent for me.) My workaround so far has been to return only solution.value, with an optional return of solution.stats, and to ignore the rest.
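
A sketch of that workaround, reusing the rosenbrock setup from the MWE above (returning stats is optional):

import jax
import jax.numpy as jnp
import optimistix as optx

def rosenbrock(x, args):
    del args
    return 10 * (x[1:] - x[:-1] ** 2), x - 1

def solve_value_only(x):
    # Return only the array-valued parts of the Solution, so that vmap never
    # sees the jaxpr stored in the solver state.
    sol = optx.least_squares(rosenbrock, optx.LevenbergMarquardt(1e-8, 1e-9), x)
    return sol.value, sol.stats

inits = jnp.zeros((4, 10))
values, stats = jax.vmap(solve_value_only)(inits)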

Why does the solution object need to include the whole jaxpr to begin with? It sure seems useful to be able to inspect during debugging, but otherwise I would have no need to look at it - maybe it could be made optional? (I optimize over the parameters of ODE models, so my jaxprs are always super long.)

In this case, the jaxpr contains the following useful message

nonbatchable[
    allow_constant_across_batch=True
    msg=Nonconstant batch. `equinox.internal.while_loop` has received a batch of values that 
    were expected to be constant. This is probably an internal error in the library you are using.
]

However, the code does not fail with that message. It fails twice - once with

.../site-packages/jax/_src/interpreters/batching.py:1107], in matchaxis(axis_name, sz, src, dst, x, sum_match)
   1105   _ = core.get_aval(x)
   1106 except TypeError as e:
-> 1107   raise TypeError(f"Output from batched function {x!r} with type "
   1108                   f"{type(x)} is not a valid JAX type") from e

And once with

.../site-packages/jax/_src/core.py:1455], in concrete_aval(x)
   1454   return concrete_aval(x.__jax_array__())
-> 1455 raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX "
   1456                  "type")

For what it's worth, here is the MWE - even if it is now probably redundant :) I added BFGS as a solver.

Edit: printed the number of lines in the jaxpr (1.2k lines, more than 100k characters). Edit 2: added forward and backward options. Edit 3: condensed the MWE (also removed the jaxpr line count).

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from jaxtyping import Array

import equinox as eqx
import diffrax as dfx
import optimistix as optx

import pytest

class ToyModel(eqx.Module):
    _term: dfx.ODETerm

    def __init__(self):
        def dydt(t, y, k):  # Monoexponential decay
            return - k * y
        self._term = dfx.ODETerm(dydt)

    def __call__(self, param):
        t0 = 0.
        t1 = 10.
        dt0 = 0.01
        y0 = jnp.array([10.])

        sol = dfx.diffeqsolve(
            self._term, 
            dfx.Tsit5(), 
            t0, t1, dt0, y0, args=param,
            saveat=dfx.SaveAt(ts=jnp.linspace(t0, t1)),
            adjoint=dfx.DirectAdjoint(),  # Supports both fwd and bwd autodiff
        )
        return sol.ys

def estimate_parameters(initial_guess, model, data, solver, solver_options: dict = dict(jac="fwd")):
    """Function that estimates the parameters."""

    def residuals(param, args):
        model, data = args
        fit = model(param)
        res = data - fit
        return res

    sol = optx.least_squares(
        residuals, 
        solver, 
        initial_guess,
        args = (model, data),
        options=solver_options,
    )
    return sol

model = ToyModel()
k = jnp.array([0.5])  # True value
ode_solution = model(k)

k0s = jnp.transpose(jnp.array([jnp.arange(0.0, 0.45, 0.05)]))

bfgs = optx.BFGS(atol=1e-09, rtol=1e-06)
lm = optx.LevenbergMarquardt(atol=1e-09, rtol=1e-06)
gn = optx.GaussNewton(atol=1e-09, rtol=1e-06)

# Add a vmap on top
vmapped_fwd_solve = jax.vmap(estimate_parameters, in_axes=(0, None, None, None))
vmapped_bwd_solve = jax.vmap(jtu.Partial(estimate_parameters, solver_options=dict(jac="bwd")), in_axes=(0, None, None, None))

for solver in [bfgs, lm, gn]:
    sol_bwd = vmapped_bwd_solve(k0s, model, ode_solution, solver)
    assert isinstance(sol_bwd, optx.Solution)

    if solver == bfgs:
        sol_fwd = vmapped_fwd_solve(k0s, model, ode_solution, solver)
        assert isinstance(sol_fwd, optx.Solution)
    else:
        with pytest.raises(TypeError):
            sol_fwd = vmapped_fwd_solve(k0s, model, ode_solution, solver)
print("Checks passed, expected errors raised. This is what happens: ")

# Repeat one of the calls that raises the error
vmapped_fwd_solve(k0s, model, ode_solution, lm)  # This fails
johannahaffner commented 1 month ago

I realise this is a bit long for an MWE - only the last line fails.

johannahaffner commented 1 month ago

I dug a little into _make_f_info from the gauss_newton module and it seems to me that the lambda function in jax.linearize(...) is redundant. If you print the returned linearized function lin_fn, you see that it has a jaxpr - which the output of jax.linearize(...) without the lambda does not.

What I do not understand is why there has to be an auxiliary argument, but I worked around it with a wrapper for now:

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.core import Jaxpr

import equinox as eqx
from optimistix._solver.gauss_newton import _make_f_info

def _for_jacrev(_y):
    """Copied from optimistix._solver.gauss_newton (for reference; not called below)."""
    f_eval, aux_eval = fn(_y, args)  # Why does this assume an auxiliary output?
    return f_eval, (f_eval, aux_eval)

def _no_nan(x):
    """Compied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x

def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)

def residuals(origin, args):
    del args
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0)**2
    true = shifted_parabola(2)
    fit = shifted_parabola(origin)
    return true - fit

def aux_wrapper(origin, args):
    return residuals(origin, args), None

initial_guess = 1.
args = ()
jac_bwd = jax.jacrev(residuals)(initial_guess, args)
jac_fwd = jax.jacfwd(residuals)(initial_guess, args)
assert tree_allclose(jac_bwd, jac_fwd)

with pytest.raises(ValueError):  # Only works with aux wrapper
    _make_f_info(residuals, initial_guess, args, set(), "bwd")

# Now with auxiliary output wrapper
(residual_jac_bwd, _) = _make_f_info(aux_wrapper, initial_guess, args, set(), "bwd")
assert tree_allclose(residual_jac_bwd.jac.pytree, jac_bwd)
(residual_jac_fwd, _) = _make_f_info(aux_wrapper, initial_guess, args, set(), "fwd")
assert isinstance(residual_jac_fwd.jac.fn.jaxpr, Jaxpr)

# The following snippet (jax.linearize...) is copied from _make_f_info (line 174)
with pytest.raises(ValueError):  # Again, must have aux
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess, has_aux=False)
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess)
    f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: residuals(_y, args), initial_guess)  # Value error: residuals does not have aux

f_eval, lin_fn, aux_eval = jax.linearize(lambda _y: aux_wrapper(_y, args), initial_guess, has_aux=True)
assert aux_eval is None
assert tree_allclose(f_eval, residuals(initial_guess, args))
print(lin_fn)  # lin_fn has jaxpr

# Compare to jax.linearize without the lambda function
res, residuals_jvp = jax.linearize(residuals, *(initial_guess, args))
assert tree_allclose(res, residuals(initial_guess, args))
assert tree_allclose(jac_bwd, residuals_jvp(initial_guess, args))

res, residuals_jvp, aux = jax.linearize(aux_wrapper, *(initial_guess, args), has_aux=True)
assert tree_allclose(res, residuals(initial_guess, args))
assert tree_allclose(jac_bwd, residuals_jvp(initial_guess, args))
assert aux is None
johannahaffner commented 1 month ago

By now you can skip most of my thought process above :) I believe it is a somewhat subtle issue involving the eval shapes passed to the linear operator in _make_f_info from gauss_newton.py. The jvps computed using jax.linearize contain jaxprs in both cases, but one evaluates to the correct jacobian and one does not.

Using the MWE below, I get

TypeError: Expected PyTreeDef((*, ())), got PyTreeDef(((*, ()),))

MWE

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import equinox as eqx
import lineax as lx
from optimistix._solver.gauss_newton import _make_f_info

def _no_nan(x):
    """Compied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x

def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)

def fn(y, args):  # Optimistix insists on aux, it seems: return an extra aux value
    del args
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0)**2
    true = shifted_parabola(2)
    fit = shifted_parabola(y)
    return true - fit, 0  # Return zero as aux

y0 = 1.
nothing = ()

# Compute jacobian two ways, ignore fn_eval, aux_eval
_, fn_j, _ = jax.linearize(lambda _y: fn(_y, nothing), y0, has_aux=True)  # Status quo
_, fn_j_no_lambda, _ = jax.linearize(fn, *(y0, nothing), has_aux=True)

# Now check the Jacobians
true_jac_eval, _ = jax.jacfwd(fn)(y0, nothing)  # Discard the Jacobian of the aux output
with pytest.raises(TypeError):
    assert tree_allclose(fn_j(y0, nothing), true_jac_eval)  # lambda used: input is unexpected pytree 
assert tree_allclose(fn_j_no_lambda(y0, nothing), true_jac_eval)

# Check eval shapes
fn_eval_shape, aux_eval_shape = jax.eval_shape(fn, *(y0, nothing))  # Returns tuple that includes aux
fn_j_no_lambda_eval_shape = jax.eval_shape(fn_j_no_lambda, *(y0, nothing))
assert tree_allclose(fn_eval_shape, fn_j_no_lambda_eval_shape)
with pytest.raises(TypeError):
    jax.eval_shape(fn_j, *(y0, nothing))  # unexpected pytree, again

# Create lx.FunctionLinearOperator
lx.FunctionLinearOperator(fn_j_no_lambda, jax.eval_shape(lambda x: x, (y0, nothing)))
johannahaffner commented 1 month ago

Related: I opened https://github.com/google/jax/issues/21581

johannahaffner commented 1 month ago

Here is my latest iteration, still poking at the two lines in _make_f_info from Gauss Newton. I could show that the jaxpr is not causing the issue - at least not outside of FunctionLinearOperator: vmapping over a Jacobian function that contains a jaxpr and is an output of jax.linearize raises no errors.

import pytest

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import equinox as eqx
import lineax as lx

def _no_nan(x):
    """Compied from test/helpers.py in diffrax."""
    if eqx.is_array(x):
        return x.at[jnp.isnan(x)].set(8.9568)  # arbitrary magic value
    else:
        return x

def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Copied from test/helpers.py in diffrax."""
    if equal_nan:
        x = jtu.tree_map(_no_nan, x)
        y = jtu.tree_map(_no_nan, y)
    return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol)

def fn(y):
    def shifted_parabola(x0):
        x = jnp.linspace(0, 10)
        return (x - x0)**2
    true = shifted_parabola(2.)  # True value
    fit = shifted_parabola(y)
    return true - fit

def aux_wrapper(y):
    return fn(y), None

y0 = 1.  # starting guess
y0s = jnp.arange(0., 4., 0.1)  # Many initial values

# Get jacobians the simple way
_, jac_of_fn = jax.linearize(fn, y0)
_, jac_of_aux_wrapper, _ = jax.linearize(aux_wrapper, y0, has_aux=True)
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_fn(y0)) 
assert tree_allclose(jax.jacfwd(fn)(y0), jac_of_aux_wrapper(y0))  

vmapped_jac_of_fn = jax.vmap(jac_of_fn)(y0s) # Does not raise error
vmapped_jac_of_aux_wrapper = jax.vmap(jac_of_aux_wrapper)(y0s) # Does not raise error

# Now include the lambda function in jax.linearize (status quo in optimistix)   
_, jac_of_fn_with_lambda = jax.linearize(lambda _y: fn(_y), y0)
_, jac_of_aux_wrapper_with_lambda, _ = jax.linearize(lambda _y: aux_wrapper(_y), y0, has_aux=True)

vmapped_jac_of_fn_with_lambda = jax.vmap(jac_of_fn_with_lambda)(y0s) # Does not raise error
vmapped_jac_of_aux_wrapper_with_lambda = jax.vmap(jac_of_aux_wrapper_with_lambda)(y0s) # Does not raise error

# Context: using lambda functions produces a subtle difference in the pytrees, which is not legible when examining the pytreedef (as a human)
with pytest.raises(AssertionError):
    assert jtu.tree_structure(jac_of_fn) == jtu.tree_structure(jac_of_fn_with_lambda)
assert str(jtu.tree_structure(jac_of_fn)) == str(jtu.tree_structure(jac_of_fn_with_lambda))

# Create a lineax Linear Operator
def lin_fun(y):
    return 2 * y
lin_op = lx.FunctionLinearOperator(lin_fun, jax.eval_shape(lin_fun, y0))  # Confirm that it works in this case

with pytest.raises(ValueError): # I don't understand why it does not work in these cases
    lx.FunctionLinearOperator(jac_of_fn, jax.eval_shape(jac_of_fn, y0))
    lx.FunctionLinearOperator(jac_of_fn_with_lambda, jax.eval_shape(jac_of_fn_with_lambda, y0))
tjltjl commented 1 month ago

A simple workaround: dataclasses.replace() the offending member of the output
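
A minimal sketch of that suggestion, reusing solve and inits from the first comment's MWE. It is hypothetical: it assumes optx.Solution behaves as a frozen dataclass (equinox modules do) and that the offending member is the solver state, which is only confirmed further down the thread.

import dataclasses
import jax

def solve_without_state(x):
    sol = solve(x)  # `solve` as defined in the first comment's MWE
    # Drop the solver state (which holds the jaxpr) before returning from vmap.
    return dataclasses.replace(sol, state=None)

out = jax.vmap(solve_without_state)(inits)  # `inits` from the first comment's MWE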

johannahaffner commented 1 month ago

The tricky thing is that I can't figure out what the offending member is.

johannahaffner commented 1 month ago

And it has now been shown that this is a deeper issue in jax.linearize, which produces pytrees with non-identical structure even for identical input functions called with identical inputs.

patrick-kidger commented 1 month ago

I think I understand what's going on here. The output of optx.least_squares includes a jaxpr inside of out.state. This isn't an array-like object, so JAX doesn't understand how to handle it as an output of the vmap. Morally speaking, what's going on here is the same as jax.vmap(lambda x: object())(...), in which, again, a non-array-like object is being returned.

The solution is pretty simple: use eqx.filter_vmap instead. This passes through all non-array-like objects unchanged. Indeed, the use case in this issue is the raison d'être of eqx.filter_vmap!
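
For reference, a sketch of the fix applied to the MWE from the first comment (reusing solve and inits as defined there):

import equinox as eqx

# filter_vmap batches the array-valued leaves of the Solution and passes
# non-array objects (such as the jaxpr in the state) through unchanged.
out = eqx.filter_vmap(solve)(inits)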

Does this solve the issues everyone is facing?

johannahaffner commented 4 weeks ago

Oh dear :D

It does solve my issue. I was actually in the process of replacing all vmaps with filter_vmaps, but there were still some around. Not anymore, though!