sammccallum / reversible

JAX implementation of the Reversible Solver method.
https://arxiv.org/abs/2410.11648

Efficient, Accurate and Stable Gradients for Neural ODEs

Overview

This repository contains a JAX implementation of the Reversible Solver method introduced in Efficient, Accurate and Stable Gradients for Neural ODEs (https://arxiv.org/abs/2410.11648).

We present a general class of algebraically reversible solvers that allows any explicit numerical solver to be made reversible. This class of reversible solvers produces exact, memory-efficient gradients.
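As a schematic of the idea (an illustrative coupling, not necessarily the exact scheme used in this repository; psi stands for one step increment of the base solver and lmbda for the coupling parameter), the state is augmented to a pair (y, z) so that each step can be inverted algebraically, letting the backward pass reconstruct earlier states instead of storing them:

def forward_step(psi, y, z, h, lmbda):
    # One reversible step built from the base solver increment psi.
    y_next = lmbda * y + (1 - lmbda) * z + psi(z, h)
    z_next = z - psi(y_next, -h)
    return y_next, z_next

def backward_step(psi, y_next, z_next, h, lmbda):
    # Algebraic inverse of forward_step: recover (y, z) from
    # (y_next, z_next) without having stored them.
    z = z_next + psi(y_next, -h)
    y = (y_next - (1 - lmbda) * z - psi(z, h)) / lmbda
    return y, z

Because earlier states are reconstructed rather than checkpointed, backpropagation through the solve uses memory independent of the number of solver steps.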

Example

A simple Neural ODE example: we wrap the Dormand–Prince 5/4 (Dopri5) solver in the Reversible class.

If the solve_forward function is called inside a jax.grad region, the memory-efficient backpropagation algorithm through the solve is used automatically (see the gradient sketch after the example).

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

from reversible.reversible_solver import Reversible
from reversible.solver_step import Dopri5
from reversible.vector_field import AbstractVectorField

# Simple neural vector field
class VectorField(AbstractVectorField):
    layers: list

    def __init__(self, key):
        key1, key2 = jr.split(key, 2)
        self.layers = [
            eqx.nn.Linear(1, 10, use_bias=True, key=key1),
            jnp.tanh,
            eqx.nn.Linear(10, 1, use_bias=True, key=key2),
        ]

    def __call__(self, t, y):
        for layer in self.layers:
            y = layer(y)
        return y

# Setup vector field
key = jr.PRNGKey(0)
vf = VectorField(key)

# Reversible Dopri5 (l is the coupling parameter of the reversible scheme)
solver = Reversible(l=0.999, solver=Dopri5())

# Solve over [0, T]
h = 0.01
T = 1
y0 = jnp.asarray(1.0)[None]  # shape (1,)
y1 = solver.solve_forward(vf, y0, h, T)
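
Gradients can then be taken through the solve as usual. A minimal sketch, using Equinox's filter_grad (a thin wrapper around jax.grad that ignores non-array leaves such as the tanh layer); the scalar loss below is a hypothetical example:

# Hypothetical scalar loss on the terminal state.
def loss(vf, y0):
    y1 = solver.solve_forward(vf, y0, h, T)
    return jnp.sum(y1**2)

# Differentiates with respect to vf; the reversible backward pass is
# used automatically inside the grad region.
grads = eqx.filter_grad(loss)(vf, y0)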

Experiments

All code to reproduce the experiments presented in the paper can be found in the experiments folder.

Installation

To install the reversible package, clone the repository and install it in editable mode:

git clone https://github.com/sammccallum/reversible.git
pip install -e reversible
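
The example above also requires JAX and Equinox; if they are not pulled in automatically as dependencies, they can be installed directly:

pip install jax equinox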