nathanaelbosch / parallel-in-time-ode-filters

Parallel-in-Time ODE Filters in Jax

Reverse Mode Differentiation #1

Open · adam-hartshorne opened this issue 11 months ago

adam-hartshorne commented 11 months ago

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

I would like to test your library within an existing framework that I have. This involves learning the parameters of a GP which controls a vector flow field and thus I need reverse-mode differentiability in order to optimise a loss function to learn these parameters.

I don't know whether you are aware of this, but the Equinox library for JAX contains undocumented functions that handle this:

https://github.com/patrick-kidger/equinox/tree/main/equinox/internal/_loop
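To illustrate why a loop with a static upper bound sidesteps this error, here is a minimal pure-JAX sketch of a "bounded" while loop. It is not Equinox's implementation, and `bounded_while_loop` and `euler_until_t1` are hypothetical names introduced only for this example. The loop always runs `max_steps` iterations inside `lax.scan` and freezes the carry once the condition becomes false, so the trip count is static and reverse mode works, assuming the real loop finishes within `max_steps` iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical helper: a while loop with a static bound on iterations.

    Runs body_fun exactly max_steps times inside lax.scan, but keeps the old
    carry once cond_fun is False, so the result matches a while loop that
    terminates within max_steps iterations. Because the trip count is static,
    reverse-mode autodiff is supported.
    """
    def step(carry, _):
        pred = cond_fun(carry)
        new_carry = body_fun(carry)
        # Only accept the update while the loop condition still holds.
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(pred, new, old), new_carry, carry
        )
        return carry, None

    final_carry, _ = lax.scan(step, init_val, None, length=max_steps)
    return final_carry


def euler_until_t1(theta, dt=0.1):
    # Explicit Euler for dy/dt = -theta * y, y(0) = 1, stepping until t reaches 1.
    # The number of steps is only known through the loop condition.
    def cond(carry):
        t, _ = carry
        return t < 1.0 - 0.5 * dt  # half-step margin guards against rounding in t

    def body(carry):
        t, y = carry
        return t + dt, y + dt * (-theta * y)

    init = (jnp.asarray(0.0), jnp.asarray(1.0))
    _, y = bounded_while_loop(cond, body, init, max_steps=32)
    return y


# Value and gradient with respect to theta, computed in reverse mode.
print(jax.value_and_grad(euler_until_t1)(0.5))
```

The price of this approach is that all `max_steps` iterations are always executed and their intermediate values are kept for the backward pass; Equinox's internal loops are more sophisticated than this, notably in how they manage memory. Also note that if `body_fun` can produce NaN or inf on the masked iterations, gradients flowing through `jnp.where` can still become NaN.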

nathanaelbosch commented 11 months ago

I have not tested the code with autodiff. Would you have a minimal example that I could try? If there is some easy way to get this to work, I would definitely be up for updating the code to support it.

adam-hartshorne commented 11 months ago

Here is an example that solves a simple neural ODE (NODE). (Obviously this isn't using the probabilistic ODE (PODE) solver properly, since it isn't optimising the NLL.)

There are plenty of examples provided with Diffrax and ProbDiffEq that require similar optimisation, for which I presume your library could be a drop-in replacement.

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jrandom
import equinox as eqx
import matplotlib.pyplot as plt
import optax
from pof.solver import solve, sequential_eks_solve

# Training data: a sine wave sampled at 100 time points on [0, 1]
ts = jnp.linspace(0, 1.0, num=100)
ys = jnp.sin(5 * jnp.pi * ts)

# Vector field f(t, y), parameterised by an MLP
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, t, y):
        return self.mlp(y)

# Neural ODE whose forward pass uses the parallel-in-time solver from pof
class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, y0):

        ts_par = jnp.linspace(0, 1.0, 100)
        ys_par, info_par = solve(f=self.func, y0=y0, ts=ts_par, order=3, init="constant")
        mean, cov = ys_par  # unpack mean and covariance
        return mean

data_size = 1
steps = 500
print_every = 10
width_size = 64
depth = 2
lr = 1e-3

key = jrandom.PRNGKey(42)
data_key, model_key, loader_key = jrandom.split(key, 3)
model = NeuralODE(data_size, width_size, depth, key=model_key)

# Mean-squared error; taking its reverse-mode gradient is what raises the ValueError.
# Note: ys has shape (100,); if y_pred has shape (100, 1), the subtraction
# broadcasts to (100, 100), so a reshape may be needed.
@eqx.filter_value_and_grad
def grad_loss(model, x0):
    y_pred = model(x0)
    return jnp.mean((ys - y_pred) ** 2)

# One optimisation step (optim is defined below, before the first call)
@eqx.filter_jit
def make_step(x0, model, opt_state):
    loss, grads = grad_loss(model, x0)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

optim = optax.adabelief(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
x0 = jnp.array([0.0])

for step in range(steps):
    loss, model, opt_state = make_step(x0, model, opt_state)
    if (step % print_every) == 0 or step == steps - 1:
        print(f"Step: {step}, Loss: {loss}")
```

nathanaelbosch commented 11 months ago

Thanks for the example! This seems to be a fundamental issue with lax.while_loop. Right now I don't see an easy way to get rid of it while still doing what the code is supposed to do. I would also prefer not to build on internal, undocumented functionality from Equinox, as it is not covered by semantic versioning, so it might change with any new release and thereby break this repository.

Maybe the best option for you to test the parallel-in-time solvers in your specific use case would be to fork the repository and replace the while loop with Equinox's. Or, if you see some other way to add reverse-mode differentiation support here, let me know and we can try to figure out how to get it implemented.
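The limitation is specific to reverse mode: `lax.while_loop` can be differentiated in forward mode, but reverse mode would need to store the values from every iteration, and their number is unknown at trace time. The following minimal sketch reproduces the ValueError quoted in the issue. It uses an explicit Euler integrator for the toy ODE dy/dt = -θy, chosen purely for illustration (`euler_while` and `euler_scan` are hypothetical names), and contrasts it with a fixed-length `lax.scan`, one of the alternatives the error message suggests.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 10
DT = 1.0 / N_STEPS


def euler_while(theta):
    """Explicit Euler for dy/dt = -theta * y, y(0) = 1, on [0, 1] via lax.while_loop."""
    def cond(carry):
        i, _ = carry
        return i < N_STEPS  # evaluated at run time, so the trip count is not static

    def body(carry):
        i, y = carry
        return i + 1, y + DT * (-theta * y)

    _, y_final = lax.while_loop(cond, body, (jnp.asarray(0), jnp.asarray(1.0)))
    return y_final


def euler_scan(theta):
    """The same integrator with a fixed trip count via lax.scan."""
    def step(y, _):
        return y + DT * (-theta * y), None

    y_final, _ = lax.scan(step, jnp.asarray(1.0), None, length=N_STEPS)
    return y_final


# Reverse mode through lax.scan works.
print(jax.grad(euler_scan)(0.5))

# Forward mode through lax.while_loop also works...
print(jax.jacfwd(euler_while)(0.5))

# ...but reverse mode through lax.while_loop raises the ValueError quoted above.
try:
    jax.grad(euler_while)(0.5)
except ValueError as err:
    print("ValueError:", err)
```

Forward-mode cost grows with the number of inputs, so `jax.jacfwd` is only a practical workaround when there are few parameters, which is not the case for the MLP in the training example above.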

adam-hartshorne commented 11 months ago

1) The undocumented functions in Equinox are very mature, as they form the basis of the widely used Diffrax ODE solver library. They are therefore unlikely to change drastically, and they are designed to replace lax.while_loop.

2) Have you looked at how ProbDiffEq achieves this? That library supports autodiff and solves probabilistic ODEs.

3) I think that for widespread adoption you will need reverse-mode differentiation support, as neural ODE solvers are widely used in diffusion models, normalising flows, etc.

adam-hartshorne commented 11 months ago

Another related suggestion would be to look into using https://github.com/wilson-labs/cola

It is a trivial drop-in replacement for base JAX operators, and it enables highly efficient linear-algebra operations via multiple dispatch and lazy evaluation, which improves speed and memory efficiency significantly.
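CoLA's own API is not shown here. As a pure-JAX illustration of the kind of structure-exploiting linear algebra described above, this sketch multiplies a vector by a Kronecker product without ever forming the product (`kron_matvec` is a hypothetical helper). Matrices of the form I ⊗ A commonly appear in ODE filters that use the same independent prior for each state dimension, which is one place where such structure pays off.

```python
import jax
import jax.numpy as jnp


def kron_matvec(A, B, x):
    """Compute jnp.kron(A, B) @ x without materialising the Kronecker product.

    With A of shape (p, q), B of shape (r, s) and x of length q * s, reshaping x
    row-major into X of shape (q, s) gives kron(A, B) @ x == (A @ X @ B.T).ravel().
    """
    q, s = A.shape[1], B.shape[1]
    X = x.reshape(q, s)
    return (A @ X @ B.T).reshape(-1)


key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (3, 4))
B = jax.random.normal(key_b, (5, 2))
x = jax.random.normal(key_x, (4 * 2,))

dense = jnp.kron(A, B) @ x   # builds the full (15, 8) matrix explicitly
lazy = kron_matvec(A, B, x)  # only ever touches the small factors
print(jnp.allclose(dense, lazy, atol=1e-4))
```

For A of shape (p, q) and B of shape (r, s), the dense route costs O(pqrs) time and memory, while the factored version needs O(ps(q + r)) work and never stores the large matrix.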

nathanaelbosch commented 11 months ago

I agree with your points! Supporting autodiff would definitely be good for a probabilistic solver library. But this repository is first and foremost meant to support our publication on the matter and to make our experiments reproducible; it is not meant as a user-facing library with lots of features. This job is better done by libraries like probdiffeq in JAX, or my ProbNumDiffEq.jl package in Julia, both of which are actively maintained, well-tested and documented. I hope to make the parallel-in-time functionality available in ProbNumDiffEq.jl in the future, but I cannot give an ETA on this.