patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

pytree output structure mismatch error in backprop during vmap #48

Closed. FFroehlich closed this issue 7 months ago.

FFroehlich commented 8 months ago

I am running into ValueError: pytree does not match out_structure errors when computing gradients of functions in which optimistix is called via vmap. The errors disappear when I replace jax.vmap with an equivalent for-loop. I have included an MWE, bug_report.py, which can switch between jax.vmap and a for-loop via the VMAP variable. My first impression is that the implicit solve during backprop gets passed the wrong (unbatched?) input vector.

Traceback (most recent call last):
  File ".../python/bug_report.py", line 74, in <module>
    loss, grads = loss_fn_w_grad(
  File ".../python/bug_report.py", line 56, in loss_fn
    output = batched_model(
  File ".../python/bug_report.py", line 44, in __call__
    return self.output_layer(opt_2st_vec(t))
  File ".../python/bug_report.py", line 22, in opt_2st_vec
    solution = optx.root_find(obj, solver, x0)
  File ".../venv/lib/python3.11/site-packages/optimistix/_root_find.py", line 227, in root_find
    return iterative_solve(
  File ".../venv/lib/python3.11/site-packages/optimistix/_iterate.py", line 346, in iterative_solve
    ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
  File ".../venv/lib/python3.11/site-packages/optimistix/_adjoint.py", line 148, in apply
    return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
  File ".../venv/lib/python3.11/site-packages/optimistix/_ad.py", line 72, in implicit_jvp
    root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: pytree does not match out_structure

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../python/bug_report.py", line 74, in <module>
    loss, grads = loss_fn_w_grad(
                  ^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/_ad.py", line 79, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 413, in _vprim_transpose
    return transpose(cts, *inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 211, in _wrapper
    cts = rule(inputs, cts_out)
          ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 272, in _linear_solve_transpose
    cts_vector, _, _ = eqxi.filter_primitive_bind(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 264, in filter_primitive_bind
    flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 299, in batch_rule
    out = _vprim_p.bind(
          ^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 337, in _vprim_abstract_eval
    outs = abstract_eval(*inputs, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/equinox/internal/_primitive.py", line 147, in _wrapper
    out = rule(*args)
          ^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 115, in _linear_solve_abstract_eval
    out = eqx.filter_eval_shape(
          ^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 86, in _linear_solve_impl
    out = solver.compute(state, vector, options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solve.py", line 632, in compute
    solution, result, _ = solver.compute(state, vector, options)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solver/lu.py", line 62, in compute
    vector = ravel_vector(vector, packed_structures)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.11/site-packages/lineax/_solver/misc.py", line 84, in ravel_vector
    raise ValueError("pytree does not match out_structure")
ValueError: pytree does not match out_structure
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

MWE:

import equinox as eqx
import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx

jax.config.update("jax_enable_x64", True)

VMAP = True

def rf(x, args, g):
    c = 1 - x[0] - x[1]
    f = -x[0] * jnp.exp(-g) - x[1]
    return f, c

def opt_2st_vec(g):
    x0 = (1 / 2, 1 / 2)
    obj = eqx.Partial(rf, g=g.squeeze())
    solver = optx.Newton(atol=1e-8, rtol=1e-8)
    solution = optx.root_find(obj, solver, x0)
    return jnp.expand_dims(solution.value[1], 0)

class Model(eqx.Module):
    input_layer: eqx.nn.Linear
    output_layer: eqx.nn.Linear

    def __init__(
        self,
        n_inputs,
        key,
    ):
        self.input_layer = eqx.nn.Linear(
            in_features=n_inputs, out_features=1, use_bias=False, key=key
        )
        self.output_layer = eqx.nn.Linear(
            in_features=1, out_features=1, use_bias=True, key=key
        )

    def __call__(self, inputs):
        t = self.input_layer(inputs)
        return self.output_layer(opt_2st_vec(t))

def loss_fn(
    params,
    static,
    inputs_folding,
    target,
):
    model = eqx.combine(params, static)
    if VMAP:
        batched_model = jax.vmap(model)
        output = batched_model(
            inputs_folding,
        )
    else:
        output = jnp.array([
            model(inputs_folding[i])
            for i in range(inputs_folding.shape[0])
        ])
    loss = jnp.mean(jnp.abs(target - output[:, 0]))
    return loss

inputs = jr.uniform(jr.PRNGKey(0), (128, 10))
target = jr.uniform(jr.PRNGKey(0), (128,))

model = Model(inputs.shape[1], jr.PRNGKey(0))

params, static = eqx.partition(model, eqx.is_array)
loss_fn_w_grad = eqx.filter_value_and_grad(loss_fn)
loss, grads = loss_fn_w_grad(
    params,
    static,
    inputs,
    target,
)

package versions:

equinox==0.11.3
jax==0.4.25
jaxlib==0.4.25
lineax==0.0.4
optimistix==0.0.6
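To make the "implicit solve during backprop" concrete: when a root-finder is differentiated implicitly, the backward pass does not differentiate through the solver iterations. Instead it solves one linear system with the Jacobian at the root (the implicit function theorem). Below is a minimal pure-JAX sketch of that idea, assuming a hand-written Newton solver; newton_solve and implicit_root are hypothetical helpers, not how Optimistix or Lineax implement this internally. It checks grad-of-vmap against a Python loop, the same comparison the MWE makes.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_solve(f, x0, g, steps=20):
    """Hypothetical helper: plain Newton iteration for f(x, g) = 0."""
    def step(_, x):
        jac = jax.jacfwd(f, argnums=0)(x, g)
        return x - jnp.linalg.solve(jac, f(x, g))
    return jax.lax.fori_loop(0, steps, step, x0)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def implicit_root(f, x0, g):
    # Forward pass: just run the solver; its iterations are never differentiated.
    return newton_solve(f, x0, g)


@implicit_root.defjvp
def implicit_root_jvp(f, primals, tangents):
    x0, g = primals
    _, g_dot = tangents  # the root does not depend on x0, so its tangent is ignored
    x = implicit_root(f, x0, g)
    # Implicit function theorem: (df/dx) x_dot + (df/dg) g_dot = 0.
    dfdx = jax.jacfwd(f, argnums=0)(x, g)
    _, dfdg_gdot = jax.jvp(lambda g_: f(x, g_), (g,), (g_dot,))
    # Reverse mode has to transpose this linear solve during backprop.
    x_dot = jnp.linalg.solve(dfdx, -dfdg_gdot)
    return x, x_dot


def rf(x, g):
    # The root satisfies x[0] = x[1] = t with t**2 + t = g.
    return jnp.array([x[0] ** 2 + x[1] - g, x[0] - x[1]])


def solve_one(g):
    return implicit_root(rf, jnp.array([0.5, 0.5]), g)[0]


gs = jnp.linspace(0.1, 1.0, 8)
grad_vmap = jax.grad(lambda gs: jnp.sum(jax.vmap(solve_one)(gs)))(gs)
grad_loop = jnp.stack([jax.grad(solve_one)(g) for g in gs])
```

Since t = (-1 + sqrt(1 + 4g)) / 2, both gradients should equal 1 / sqrt(1 + 4g).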
patrick-kidger commented 8 months ago

Hmm, I think I agree that sounds like a plausible root cause.

I'm still looking at this, but FWIW I've managed to reduce it to this MWE. Curiously, the type of x0 seems to affect whether a crash is generated. Right now I'm not sure why that should be!

import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx

jax.config.update("jax_enable_x64", True)

CRASH = True

def rf(x, g):
    return x[0], x[1] - g

def opt_2st_vec(g):
    if CRASH:
        x0 = (0.5, 0.5)
    else:
        x0 = jnp.array([0.5, 0.5])
    solver = optx.Newton(atol=1e-8, rtol=1e-8)
    solution = optx.root_find(rf, solver, x0, args=g)
    return solution.value[0]

def loss_fn(x):
    return jnp.sum(jax.vmap(opt_2st_vec)(x))

x = jr.uniform(jr.key(0), (128,))
jax.grad(loss_fn)(x)

I'll keep poking at this, but let me know if you find anything sooner than that.
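One visible difference between the two choices of x0 is their pytree structure: the tuple (0.5, 0.5) has two scalar leaves, whereas jnp.array([0.5, 0.5]) is a single array leaf. The sketch below shows this using jax.tree_util and jax.flatten_util.ravel_pytree. The ravel_checked helper is hypothetical and only loosely mirrors the kind of structure check that produces the "pytree does not match out_structure" message. It is not Lineax's code and does not by itself explain the crash.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

x0_tuple = (0.5, 0.5)             # a pytree with two scalar leaves
x0_array = jnp.array([0.5, 0.5])  # a single array leaf

tuple_def = jax.tree_util.tree_structure(x0_tuple)
array_def = jax.tree_util.tree_structure(x0_array)


def ravel_checked(vector, expected_treedef):
    """Hypothetical helper: flatten `vector` only if its structure matches."""
    if jax.tree_util.tree_structure(vector) != expected_treedef:
        raise ValueError("pytree does not match out_structure")
    flat, _ = ravel_pytree(vector)  # concatenate all leaves into one 1D array
    return flat


print(tuple_def.num_leaves, array_def.num_leaves)  # 2 1
print(ravel_checked(x0_tuple, tuple_def).shape)    # (2,)
```

Passing x0_array together with tuple_def raises the ValueError, because a single-leaf structure is compared against a two-leaf one.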

patrick-kidger commented 8 months ago

Okay, got it! Looks like grad-of-vmap-of-<a linear_solve that we only use some of the outputs from> threaded the needle to hit a case we didn't handle correctly.

I've opened https://github.com/patrick-kidger/equinox/pull/671 and https://github.com/patrick-kidger/lineax/pull/84 to fix this. (Although the Lineax CI will fail as it can't see the updated Equinox PR.) I'm hoping to do new Equinox and Lineax releases, including these fixes, in the next few days.
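The problematic pattern can be written down in plain JAX. In this sketch jnp.linalg.solve stands in for Lineax's linear_solve, which is an assumption for illustration: it runs fine in plain JAX and does not exercise the Lineax code path that had the bug. The function solves a 2x2 system but returns only one component of the solution. It is then batched with vmap and differentiated with grad.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])


def first_component(g):
    b = jnp.array([g, 1.0])
    x = jnp.linalg.solve(A, b)
    return x[0]  # only one of the two solution components is used


def loss(gs):
    return jnp.sum(jax.vmap(first_component)(gs))


gs = jnp.linspace(0.0, 1.0, 5)
grad_vmap = jax.grad(loss)(gs)  # grad-of-vmap-of-a-partially-used-solve
grad_loop = jnp.stack([jax.grad(first_component)(g) for g in gs])
# A^{-1} = [[3, -1], [-1, 2]] / 5, so d x[0] / d g = 3/5 for every g.
```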

FFroehlich commented 8 months ago

Fantastic, thanks for the quick fix and workaround.

johannahaffner commented 8 months ago

Hi @patrick-kidger and @FFroehlich,

I might have a related issue. It persists even with the fixes in equinox@dev and lineax@vprim_transpose_symbolic_zeros.

I'm vmapping a nonlinear solve (parameter estimation for ODEs across many individuals, each with their own parameter set).

I get the error "ValueError: Unexpected batch tracer. This operation cannot be vmap'd.", raised by _cannot_batch in equinox/internal/_nontraceable.py, which calls into jax.interpreters.batching. (The full traceback is very long.)

The error goes away if I use a for-loop, and it also goes away with a nonlinear solver that does not use gradients (Nelder-Mead).

I'm working on an MWE, starting by adapting yours from above, @patrick-kidger.

For added context: I have a nested hierarchical model composed of Equinox modules, and I now want to optimize the final (population-level) layer in a way that leverages JAX's SPMD capabilities.

johannahaffner commented 8 months ago

Here comes the MWE.

import jax.random as jr
import jax.numpy as jnp
import jax
import optimistix as optx
import equinox as eqx
import diffrax as dfx

jax.config.update("jax_enable_x64", True)

GRAD = True
VMAP = True

def dydt(t, y, args):
    k = args
    return -k * y

class Individual(eqx.Module):
    term: dfx.ODETerm
    solver: dfx.Tsit5
    y0: float
    t0: int
    t1: int
    dt0: float
    saveat: dfx.SaveAt

    def __init__(self, ode_system, y0):
        self.term = dfx.ODETerm(ode_system)
        self.solver = dfx.Tsit5()

        self.y0 = y0
        self.t0 = 0
        self.t1 = 10
        self.dt0 = 0.01
        self.saveat = dfx.SaveAt(ts=jnp.arange(self.t0, self.t1, self.dt0))

    def simulate(self, args):
        sol = dfx.diffeqsolve(
            self.term, 
            self.solver, 
            self.t0, 
            self.t1, 
            self.dt0, 
            self.y0, 
            args=args, 
            saveat=self.saveat,
            adjoint=dfx.DirectAdjoint(),
        )
        return sol.ys

    def estimate_param(self, initial_param, ydata, solver):
        args = (self.simulate, ydata)

        def residuals(param, args):
            model, ydata = args
            yfit = model(param)
            res = ydata - yfit
            return res

        sol = optx.least_squares(
            residuals,
            solver, 
            initial_param,
            args=args,
        )
        return sol.value

m = Individual(dydt, 10.)

def generate_data(individual_model):  # Noise-free
    k0s = (0.3, 0.5, 0.7)  # Vary parameters
    ydata = []
    for k0 in k0s:
        y = individual_model.simulate(k0)
        ydata.append(y)
    return jnp.array(ydata)

data = generate_data(m)
initial_k0 = 0.5  # Starting point for all 

def run(initial_param, individual_model, individual_data):
    if GRAD:
        solver = optx.LevenbergMarquardt(rtol=1e-07, atol=1e-07)
    else: 
        solver = optx.NelderMead(rtol=1e-07, atol=1e-07)
    if VMAP:
        get_params = jax.vmap(individual_model.estimate_param, in_axes=(None, 0, None))
        params = get_params(initial_param, individual_data, solver)
    else:
        params = [individual_model.estimate_param(initial_param, y, solver) for y in individual_data]
    return params

params = run(initial_k0, m, data)
print(params)

And this is how it behaves (with equinox@dev and lineax@vprim_transpose_symbolic_zeros).

With GRAD and VMAP both set to True, I get ValueError: Unexpected batch tracer. This operation cannot be vmap'd. The other three combinations work.
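For comparison, the vmap-over-individuals structure of this MWE can be written in plain JAX without Optimistix or Diffrax. This is a sketch under stated assumptions: the closed-form solution y0 * exp(-k t) replaces diffeqsolve, and a hand-written Gauss-Newton loop (the hypothetical helper fit, not optimistix.LevenbergMarquardt) replaces the least-squares solver. The Jacobian from jax.jacfwd is taken inside the vmapped function, which is the vmap-of-a-gradient composition that fails above. The sketch only shows the pattern; it does not reproduce the error.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

y0 = 10.0
ts = jnp.arange(0.0, 10.0, 0.01)


def model(k):
    # Closed-form solution of dy/dt = -k * y, standing in for diffeqsolve.
    return y0 * jnp.exp(-k * ts)


def residuals(k, ydata):
    return ydata - model(k)


def fit(k_init, ydata, steps=50):
    """Hypothetical hand-written Gauss-Newton fit of the scalar rate k."""
    k_init = jnp.asarray(k_init, dtype=ts.dtype)

    def step(_, k):
        r = residuals(k, ydata)
        jac = jax.jacfwd(residuals)(k, ydata)  # d r / d k, shape (len(ts),)
        return k - jnp.dot(jac, r) / jnp.dot(jac, jac)

    return jax.lax.fori_loop(0, steps, step, k_init)


true_ks = jnp.array([0.3, 0.5, 0.7])
data = jax.vmap(model)(true_ks)  # shape (3, len(ts)), noise-free

# Shared initial guess, batched data: the same in_axes pattern as in the MWE.
ks_vmap = jax.vmap(fit, in_axes=(None, 0))(0.5, data)
ks_loop = jnp.stack([fit(0.5, y) for y in data])
```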

johannahaffner commented 8 months ago

A few postscripts:

  1. eqx.filter_vmap does not make a difference.

  2. I noticed that your example also uses a combination of vmap, diffeqsolve and least squares. Since you batch inside the residuals function, your composition is grad(vmap(...)), which works. Mine is vmap(grad(...)), which does not. (I tried running it both with and without jax.config.update('jax_disable_jit', True).)

  3. Replacing Individual.simulate(...) with Individual.__call__(...), and defining the function estimate_param outside of the Individual class, also does not change anything. I had wondered whether the problem was that everything happens inside bound methods.

  4. The same happens with BFGS. GradientDescent and NonlinearCD do not converge, so I can't judge them using this MWE. However, it does not happen with GaussNewton.

  5. I don't think the problem is in Lineax: GaussNewton works with QR, which is the default linear solver for Levenberg-Marquardt.
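The two compositions in point 2 differ only in nesting order. For a sum of independent per-individual terms, grad(vmap(...)) and vmap(grad(...)) compute the same per-individual derivatives mathematically. What differs is which transformation is applied on the outside, and therefore which code paths in the libraries get traced. A minimal plain-JAX sketch with a toy per-individual function f (an assumption for illustration):

```python
import jax
import jax.numpy as jnp


def f(k):
    # Any scalar-to-scalar function of one individual's parameter.
    return jnp.sin(k) * jnp.exp(-k)


ks = jnp.linspace(0.1, 1.0, 6)

# grad(vmap(...)): differentiate the summed batch loss w.r.t. all parameters at once.
grad_of_vmap = jax.grad(lambda ks: jnp.sum(jax.vmap(f)(ks)))(ks)

# vmap(grad(...)): differentiate one individual's function, then batch it.
vmap_of_grad = jax.vmap(jax.grad(f))(ks)

# Analytical derivative for reference: (cos k - sin k) * exp(-k).
expected = (jnp.cos(ks) - jnp.sin(ks)) * jnp.exp(-ks)
```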

patrick-kidger commented 7 months ago

Thank you for the issue! This was a fairly tricky one.

Ultimately I think this is a sort-of bug (or at least a questionable design decision) in jax.checkpoint. This seems to be something we can work around, however, in Equinox. As such I've opened https://github.com/patrick-kidger/equinox/pull/694 to address this.

Can you give that branch a go on your actual (non-MWE) problem, and let me know if that fixes things? If so then I'll merge it.
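For readers unfamiliar with jax.checkpoint (also available as jax.remat): it wraps a function so that its intermediate values are recomputed during the backward pass instead of being stored, trading compute for memory. The sketch below shows the API under the vmap(grad(...)) composition discussed in this thread. The function inner is a hypothetical toy. The sketch only demonstrates that checkpointing leaves gradients unchanged; it does not reproduce the bug or show where Equinox or Optimistix use checkpointing.

```python
import jax
import jax.numpy as jnp


def inner(k, t):
    # A toy computation; under jax.checkpoint its intermediates are recomputed
    # on the backward pass rather than saved from the forward pass.
    return jnp.sum(jnp.exp(-k * t) ** 2)


inner_remat = jax.checkpoint(inner)

ts = jnp.linspace(0.0, 1.0, 11)
ks = jnp.array([0.3, 0.5, 0.7])

# vmap(grad(...)) over the parameter, with and without checkpointing.
g_plain = jax.vmap(jax.grad(inner), in_axes=(0, None))(ks, ts)
g_remat = jax.vmap(jax.grad(inner_remat), in_axes=(0, None))(ks, ts)
```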

johannahaffner commented 7 months ago

It works!

Thank you so much for taking a look at this, even during the Easter holidays. It is very much appreciated!

I want to add that I am new to the ecosystem and enjoy it very much; it is so well thought-through and documented. I hope I can start contributing something other than questions as I get to know it better :)

patrick-kidger commented 7 months ago

Awesome stuff, I'm glad to hear it! I hope you enjoy using the ecosystem. :)

On this basis I've just merged the fix, so it will appear in the next release of Equinox.