patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

[Question] Neural CDE for regression #519

Open suargi opened 1 week ago

suargi commented 1 week ago

Description

I would like to create a neural CDE for regression. To do this, I have taken the neural CDE for classification example and adapted it using content from the neural ODE for regression example.

I am encountering some issues which I do not know how to solve. I would appreciate it if someone could point me in the right direction. Thank you!

Code

import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(in_size=data_size, out_size=data_size, width_size=width_size, depth=depth, key=ikey)
        self.func = Func(data_size, width_size, depth, key=fkey)

    def __call__(self, ts, coeffs, evolving_out=False):
        # Each sample of data consists of some timestamps `ts`, and some `coeffs`
        # parameterising a control path. These are used to produce a continuous-time
        # input path `control`.
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[1] - ts[0]
        y0 = self.initial(control.evaluate(ts[0]))
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

# ============================================================================
def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys

def get_data(dataset_size, *, key):
    length = 100
    ts = jnp.linspace(0, 10, length)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    ts_broadcasted = jnp.broadcast_to(ts, (dataset_size, length))
    ys = jnp.concatenate([ts_broadcasted[:, :, None], ys], axis=-1) # time is a channel
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts_broadcasted, ys)
    return ts_broadcasted, ys, coeffs

# ============================================================================

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys, coeffs = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralCDE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # Only thing to notice is that up until step 500 we train on only the first 10% of
    # each time series. This is a standard trick to avoid getting caught in a local
    # minimum.

    @eqx.filter_jit # value_and_grad
    def loss(model, ti, yi, coeff_i):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
        # MSE without time column
        return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(data_i, model, opt_state):
        ti, yi, *coeff_i = data_i
        loss, grads = grad_loss(model, ti, yi, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adam(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[:, : int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        _coeffs = tuple(arr[:, :int(length_size * length) - 1] for arr in coeffs)
        for step, data_i in zip(
            range(steps), dataloader((_ts, _ys) + _coeffs, batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(data_i, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        sample_coeffs = tuple(c[-1] for c in coeffs)
        pred = model(ts, sample_coeffs, evolving_out=True)
        plt.plot(ts, pred[:, 0], c="crimson", label="Model")
        plt.plot(ts, pred[:, 0], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model

ts, ys, model = main()

Error

The error originates at line

return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

The error message is too large to write down here; to reproduce it, please run the code above. Note: my intention is to compute the MSE between the predicted values and the true values. The variables yi and y_pred carry the timestamps in their first channel, so for the MSE I only use the last two channels.
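
For concreteness, the shapes involved (as produced by get_data above) are:

# yi:     (batch_size, length, 3)  -- channel 0 is time, channels 1 and 2 are the state
# y_pred: (batch_size, length, 3)  -- same layout, predicted by the model
# yi[:, :, 1:] and y_pred[:, :, 1:] therefore have shape (batch_size, length, 2)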

Specifications

jax 0.4.35
jaxlib 0.4.35
jaxtyping 0.2.34
diffrax 0.6.0
equinox 0.11.8
numpy 2.1.2
optax 0.2.3

lockwo commented 1 week ago

You have grad_loss = eqx.filter_value_and_grad(loss, has_aux=True), but your loss function doesn't actually return any auxiliary values. Setting has_aux to False yields:

Step: 0, Loss: 0.17582178115844727, Computation time: 13.726179122924805
Step: 100, Loss: 0.012010098434984684, Computation time: 0.04791116714477539
Step: 200, Loss: 0.01128536369651556, Computation time: 0.06833648681640625
Step: 300, Loss: 0.006681683007627726, Computation time: 0.03933405876159668
Step: 400, Loss: 0.008453472517430782, Computation time: 0.034162044525146484
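
In other words: either keep has_aux=False with a scalar-valued loss, or return an explicit (loss, aux) pair when has_aux=True. A minimal sketch of the two patterns (illustrative only, not tested against your full script):

def loss_scalar(model, ti, yi, coeff_i):
    y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
    return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

# Pattern 1: scalar output only -> has_aux=False
grad_loss = eqx.filter_value_and_grad(loss_scalar, has_aux=False)
# value, grads = grad_loss(model, ti, yi, coeff_i)

# Pattern 2: return a (loss, aux) tuple -> has_aux=True
def loss_with_aux(model, ti, yi, coeff_i):
    y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
    mse = jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)
    return mse, y_pred  # y_pred is the auxiliary output

grad_loss_aux = eqx.filter_value_and_grad(loss_with_aux, has_aux=True)
# (value, y_pred), grads = grad_loss_aux(model, ti, yi, coeff_i)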
suargi commented 1 week ago

Thank you, that solved the issue.

I have tried different hyperparameter combinations (number of epochs, learning rate, number of layers, etc.), but I cannot get results as accurate as with the neural ODE. I am wondering if there is some problem with my code. Would it be possible for you to take a look and verify that my implementation is correct? Thank you.

Updated code:

import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(in_size=data_size, out_size=data_size, width_size=width_size, depth=depth, key=ikey)
        self.func = Func(data_size, width_size, depth, key=fkey)

    def __call__(self, ts, coeffs, evolving_out=False):
        # Each sample of data consists of some timestamps `ts`, and some `coeffs`
        # parameterising a control path. These are used to produce a continuous-time
        # input path `control`.
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[1] - ts[0]
        y0 = self.initial(control.evaluate(ts[0]))
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

# ============================================================================
def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys

def get_data(dataset_size, *, key):
    length = 100
    ts = jnp.linspace(0, 10, length)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    ts_broadcasted = jnp.broadcast_to(ts, (dataset_size, length))
    ys = jnp.concatenate([ts_broadcasted[:, :, None], ys], axis=-1) # time is a channel
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts_broadcasted, ys)
    return ts_broadcasted, ys, coeffs

# ============================================================================

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

def main(
    dataset_size=256,
    batch_size=64,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys, coeffs = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralCDE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # (Note: in this updated version length_strategy=(1, 1), so we train on the
    # full time series in both stages, rather than only on the first 10% at the
    # start as in the original example.)

    @eqx.filter_jit # value_and_grad
    def loss(model, ti, yi, coeff_i):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
        # MSE without time column
        return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=False)

    @eqx.filter_jit
    def make_step(data_i, model, opt_state):
        ti, yi, *coeff_i = data_i
        loss, grads = grad_loss(model, ti, yi, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adam(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[:, : int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        _coeffs = tuple(arr[:, :int(length_size * length) - 1] for arr in coeffs)
        for step, data_i in zip(
            range(steps), dataloader((_ts, _ys) + _coeffs, batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(data_i, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        ts = ts[0, :]
        plt.plot(ts, ys[0, :, 1], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 2], c="dodgerblue")
        sample_coeffs = tuple(c[-1] for c in coeffs)
        pred = model(ts, sample_coeffs, evolving_out=True)
        plt.plot(ts, pred[:, 1], c="crimson", label="Model")
        plt.plot(ts, pred[:, 2], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, coeffs, model

ts, ys, coeffs, model = main()
lockwo commented 1 week ago

I'm probably not familiar enough with Neural CDEs to be able to diagnose issues without substantial investigation. I would recommend checking piece by piece to make sure each of the subroutines is operating as expected, e.g. by comparing to specific known solutions on small problems.
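
For example, here is a rough sketch of the kind of check I mean (names like identity_vf are just for illustration, and I haven't run this against your exact setup): the cubic interpolation should pass through the data points, and a CDE with an identity-matrix vector field should simply reproduce the control path, since then dy = dX.

import diffrax
import jax.numpy as jnp

ts = jnp.linspace(0, 10, 100)
xs = jnp.stack([ts, jnp.sin(ts), jnp.cos(ts)], axis=-1)  # a known 3-channel control path

# 1) The Hermite cubic interpolant should pass through the data points.
coeffs = diffrax.backward_hermite_coefficients(ts, xs)
control = diffrax.CubicInterpolation(ts, coeffs)
assert jnp.allclose(control.evaluate(ts[17]), xs[17], atol=1e-4)

# 2) With an identity vector field, dy = dX, so the solution should track X(t).
def identity_vf(t, y, args):
    return jnp.eye(3)  # matrix of shape (y_dim, control_dim)

term = diffrax.ControlTerm(identity_vf, control).to_ode()
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    ts[0],
    ts[-1],
    0.1,
    xs[0],  # y0 = X(t0)
    stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
    saveat=diffrax.SaveAt(ts=ts),
)
assert jnp.allclose(sol.ys, xs, atol=1e-3)

If checks like these pass, I would then test the dataloader and the loss in isolation in the same way.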