patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.37k stars, 124 forks

RecursiveCheckpointAdjoint not working for two-level minimisation #465

Open ddrous opened 1 month ago

ddrous commented 1 month ago

Hi all, I've edited the introductory Neural ODE example to highlight a problem I'm facing with two-level optimisation: the first (outer) level is wrt the model, and the second (inner) level is wrt a parameter alpha. JAX throws a JaxStackTraceBeforeTransformation error if I use RecursiveCheckpointAdjoint, but everything runs if I use DirectAdjoint instead. In line with the recommendations in the documentation, I'd love to use the former adjoint rule. Please help. Thanks.

import equinox as eqx
import diffrax
import jax
import jax.numpy as jnp

data_size=2

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self):
        self.mlp = eqx.nn.MLP(
            in_size=data_size+1,
            out_size=data_size,
            width_size=4,
            depth=2,
            activation=jax.nn.softplus,
            key=jax.random.PRNGKey(0),
        )

    def __call__(self, t, y, args):
        alpha = args[0]
        y = jnp.concatenate([y, alpha])
        return self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self):
        self.func = Func()

    def __call__(self, ts, y0, alpha):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            args=(alpha,),
            # adjoint=diffrax.DirectAdjoint(),               ## works fine ! 🎉
            adjoint=diffrax.RecursiveCheckpointAdjoint(),    ## throws a JaxStackTraceBeforeTransformation 😢
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

def loss_fn(model, alpha):
    ts = jnp.linspace(0, 1, 100)
    y0 = jnp.zeros(data_size)
    return jnp.mean(model(ts, y0, alpha) ** 2)

def inner_step(model, alpha):
    alpha_grad = eqx.filter_grad(lambda alpha, model: loss_fn(model, alpha))(alpha, model)
    return jnp.mean(alpha_grad)

def outer_step(model, alpha):
    model_grad = eqx.filter_grad(inner_step)(model, alpha)
    return model_grad

model = NeuralODE()
alpha = jnp.array([1.])

## Run the outer step
outer_step(model, alpha)

patrick-kidger commented 1 month ago

You're actually bumping into something that I think is a bit of an open research problem. :) Namely, how to do second-order autodifferentiation whilst using checkpointing! In particular, what you're seeing here is that the backward pass for RecursiveCheckpointAdjoint is not itself reverse-mode autodifferentiable.

I do note that alpha appears to be a scalar. I've not thought through every detail, but for such cases it is usually more efficient to use jax.jvp to perform forward-mode autodifferentiation instead. Typically the goal is to frame the computation as a jvp-of-grad-of-loss. (Such 'forward over reverse' is usually most efficient overall.) This may allow you to sidestep this problem.

Failing that, using DirectAdjoint is probably the best option available here.
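
For readers unfamiliar with the 'forward over reverse' framing mentioned above, here is a minimal sketch in plain JAX. It is only an illustration under stated assumptions: `toy_loss` and `theta` are hypothetical stand-ins for the ODE-based `loss_fn` and the model parameters, and no diffrax solve is involved, so it does not show whether this composes with RecursiveCheckpointAdjoint. It relies on the fact that, for a scalar alpha, the gradient wrt the model of d(loss)/d(alpha) equals the derivative wrt alpha of the model gradient, which a single jvp can compute.

```python
import jax
import jax.numpy as jnp


def toy_loss(theta, alpha):
    # Hypothetical stand-in for loss_fn(model, alpha); no ODE solve here.
    return jnp.mean(jnp.tanh(theta * alpha[0]) ** 2)


theta = jnp.array([0.3, -1.2, 0.7])  # stand-in for the model parameters
alpha = jnp.array([1.0])             # scalar inner parameter, as in the MWE


def grad_theta(a):
    # Reverse mode: gradient of the loss wrt the "model" at a given alpha.
    return jax.grad(toy_loss, argnums=0)(theta, a)


# Forward over reverse: push a unit tangent in alpha through grad_theta.
# The tangent output is d/d(alpha) of grad_theta, i.e. the mixed derivative.
_, mixed_fwd_rev = jax.jvp(grad_theta, (alpha,), (jnp.ones_like(alpha),))


def inner(th):
    # Reverse over reverse, mirroring inner_step/outer_step from the issue.
    return jnp.mean(jax.grad(toy_loss, argnums=1)(th, alpha))


mixed_rev_rev = jax.grad(inner)(theta)
```

On this toy problem both routes compute the same mixed derivative; the forward-over-reverse route only ever applies reverse mode once, so the backward pass itself is never reverse-mode differentiated. This only helps when the inner parameter is low-dimensional, since each jvp covers a single tangent direction.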

ddrous commented 1 month ago

Thank you @patrick-kidger. It helps to know what the real problem is. Looking forward to any research/development on this in the future.

Using JVPs is not really an option for me since my parameters are themselves neural nets (I turned alpha into a scalar just for the purpose of a MWE). So it looks like I'm going to have to use DirectAdjoint(), even though I can barely handle its memory requirements (and that is after tweaking max_steps).