sammccallum / reversible

JAX implementation of the Reversible Solver method.
https://arxiv.org/abs/2410.11648

Solution including trajectory at intermediate times #1

adam-hartshorne opened this issue 1 month ago

adam-hartshorne commented 1 month ago

As far as I can tell, at the moment the solution is just returned at t=T. Is it possible to return a trajectory of values at intermediate times, e.g. as in https://docs.kidger.site/diffrax/api/saveat/? And would it cause a significant slow down (as it does in Diffrax)?

sammccallum commented 1 month ago

Hi Adam, thanks for the question.

Yep, you are correct. Currently the solution is returned at $t=T$ as the most common workflow for Neural ODEs is to define a loss $L(y(T))$ dependent on the terminal state.

Alternatively, if your loss function depends on the full trajectory, $L(y(t))$, then it is possible to calculate the loss as an integral alongside the ODE solve (without saving the full trajectory). This was done for the CNF example in the paper and is more memory-efficient than storing the trajectory.
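
For example, here is a minimal sketch of that pattern (not the exact CNF code from the paper; the vector field f and running cost g below are just placeholders):

import jax.numpy as jnp
from diffrax import ODETerm, Tsit5, diffeqsolve

def f(t, y, args):
    return -y  # placeholder vector field

def g(y):
    return jnp.sum(y**2)  # placeholder running cost to integrate

def augmented(t, state, args):
    # Solve dy/dt = f(t, y) and dl/dt = g(y) together, so the loss
    # l(T) = ∫ g(y(t)) dt is accumulated without storing the trajectory.
    y, _ = state
    return f(t, y, args), g(y)

y0 = jnp.array([1.0])
sol = diffeqsolve(
    ODETerm(augmented), Tsit5(), t0=0, t1=5, dt0=0.01, y0=(y0, jnp.array(0.0))
)
yT = sol.ys[0][-1]    # terminal state y(T)
loss = sol.ys[1][-1]  # ∫ g(y(t)) dt, computed without saving the trajectory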

Anyway, it would be possible to add the option to return the full trajectory. This would require tweaking the backpropagation algorithm: we would need to update the intermediate solution tangents $\bar{y}_n$ at each time-step $n$, rather than calculating them from zero.
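
Roughly, the required change would be (a sketch, assuming a loss $L = \sum_n L_n(y_n)$ over the saved states): at each reconstructed step the local gradient is added to the tangent,

$$\bar{y}_n \leftarrow \bar{y}_n + \frac{\partial L_n}{\partial y_n},$$

rather than only seeding $\bar{y}_N = \partial L / \partial y_N$ at the terminal time.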

As for Diffrax, I believe the solution is interpolated over the (adaptive) time-steps taken and evaluated at the user-specified times, which may explain the slow down. If the step sizes are fixed, then saving at intermediate times just corresponds to filling an array of pre-specified shape, which wouldn't be expensive.
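
Concretely (standard diffrax here, nothing from my fork), the two saving modes look like this; SaveAt(ts=...) evaluates a local interpolant at the requested times, while SaveAt(steps=True) just writes each accepted step into a pre-allocated array:

import jax.numpy as jnp
from diffrax import ODETerm, SaveAt, Tsit5, diffeqsolve

term = ODETerm(lambda t, y, args: -y)  # toy vector field
y0 = jnp.array([1.0])

# Saved at user-specified times, via interpolation between solver steps
sol_ts = diffeqsolve(
    term, Tsit5(), t0=0, t1=5, dt0=0.01, y0=y0,
    saveat=SaveAt(ts=jnp.linspace(0, 5, 51)),
)

# Saved at every solver step, written into a pre-allocated array
sol_steps = diffeqsolve(
    term, Tsit5(), t0=0, t1=5, dt0=0.01, y0=y0,
    saveat=SaveAt(steps=True),
)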

adam-hartshorne commented 1 month ago

My use case might be quite niche, but I need the locations along the trajectory at predefined intervals, which are then used in a number of different loss functions.

I currently use a probabilistic ODE solver (https://github.com/pnkraemer/probdiffeq), both because I can output the trajectory efficiently and because in some use cases I want to model the problem with uncertainty from both the solve and the data. However, in other use cases I don't need any UQ and am just looking for the most efficient solver.

sammccallum commented 15 hours ago

@adam-hartshorne I now have this implemented in my fork of diffrax - see the reversible-steps branch. This is still a work in progress but may be useful for your application.

Here's an example usage:

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from diffrax import ODETerm, Reversible, ReversibleAdjoint, SaveAt, Tsit5, diffeqsolve

# Simple neural vector field
class VectorField(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2 = jr.split(key, 2)
        self.layers = [
            eqx.nn.Linear(1, 10, use_bias=True, key=key1),
            jnp.tanh,
            eqx.nn.Linear(10, 1, use_bias=True, key=key2),
        ]

    def __call__(self, t, y, args):
        for layer in self.layers:
            y = layer(y)
        return y

if __name__ == "__main__":
    vf = VectorField(jr.PRNGKey(0))
    y0 = jnp.array([1.0])
    term = ODETerm(vf)
    solver = Reversible(Tsit5(), l=0.999)

    sol = diffeqsolve(
        term,
        solver,
        t0=0,
        t1=5,
        dt0=0.01,
        y0=y0,
        adjoint=ReversibleAdjoint(),
        saveat=SaveAt(t0=True, steps=True),
    )
    # Entries beyond the accepted steps are padding (inf), so trim to
    # the initial condition plus the accepted steps.
    final_index = sol.stats["num_accepted_steps"]
    ys = sol.ys[: final_index + 1]

To use the Reversible solver, you can wrap any AbstractRungeKutta method in diffrax (here we used Tsit5). For memory-efficient backpropagation, we pass adjoint=ReversibleAdjoint() to diffeqsolve. We save the steps that the solver took (plus initial condition) using SaveAt(t0=True, steps=True).

Note that a gradient can be taken w.r.t. any of the ys produced by SaveAt, which now includes the trajectory. Also, no interpolation is used, so this should be fast! (In my testing, ReversibleAdjoint is faster than the default checkpointed backpropagation.)
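
For example, a rough sketch of taking a gradient through the saved trajectory (continuing from the script above; target_ys and the mean-squared-error loss are placeholders, and the 500 steps follow from the fixed step size):

import equinox as eqx
import jax.numpy as jnp

num_steps = 500  # (t1 - t0) / dt0 with the fixed step size above

@eqx.filter_grad
def trajectory_loss(vf, y0, target_ys):
    sol = diffeqsolve(
        ODETerm(vf),
        Reversible(Tsit5(), l=0.999),
        t0=0,
        t1=5,
        dt0=0.01,
        y0=y0,
        adjoint=ReversibleAdjoint(),
        saveat=SaveAt(t0=True, steps=True),
    )
    ys = sol.ys[: num_steps + 1]  # static slice: initial condition + every step
    return jnp.mean((ys - target_ys) ** 2)

target_ys = jnp.ones((num_steps + 1, 1))  # placeholder target trajectory
grads = trajectory_loss(vf, y0, target_ys)  # gradients w.r.t. the vector field parameters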