patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Fitting ODE model with `diffeqsolve` is extremely slow using NUTS on GPU #338

Open kokbent opened 10 months ago

kokbent commented 10 months ago

As the title says, I've been trying to fit my SIR ODE model using NUTS on a GPU, but the fit is extremely slow compared to the CPU. I'm using JAX and NumPyro for the fitting, and I ran this on Google Colab:

CPU sample: 100%|██████████| 2000/2000 [02:15<00:00, 14.78it/s, 7 steps of size 3.16e-01. acc. prob=0.94]

GPU (had to interrupt because it's too slow) warmup: 1%| | 16/2000 [05:11<10:43:38, 19.47s/it, 1 steps of size 2.14e-04. acc. prob=0.58]

This issue is not specific to diffrax: I had the same problem using odeint as my ODE solver. Searching online, I found a similar issue (with odeint) reported against JAX: "Gradients with odeint slow on GPU" #5006. According to one of the replies, the tight loop structure in odeint is not XLA GPU friendly. Since I see a similar slowdown with diffeqsolve, I guess it uses a similar technique and suffers from the same problem? The question then is: is there any way to circumvent the problem within the diffrax package, perhaps with another type of implementation?


Here's the code I use:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jax import random
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5

numpyro.set_platform("cpu")

def sir_ode(state, _, parameters):
    # Unpack state
    s, i, r = state
    beta, gamma = parameters
    population = s + i + r

    # Compute flows
    ds_to_i = beta * s * i / population
    di_to_r = gamma * i

    # Compute derivatives
    ds = -ds_to_i
    di = ds_to_i - di_to_r
    dr = di_to_r

    return (ds, di, dr)  # jnp.stack([ds, di, dr])

# Parameters
beta = 1.5 / 4.5
gamma = 1.0 / 4.5
population = 10000
initial_infections = 1.0

initial_state = (
    population - initial_infections,  # s
    initial_infections,  # i
    0.0,  # r
)

# Solve ODE
term = ODETerm(lambda t, state, parameters: sir_ode(state, t, parameters))
solver = Tsit5()
t0 = 0.0
t1 = 100.0
dt0 = 0.1
times = jnp.linspace(t0, t1, 101)
saveat = SaveAt(ts=times)

def des(initial_state, args):
    solution = diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        initial_state,
        args=args,
        saveat=saveat,
    )
    return solution

sol = des(initial_state, [beta, gamma])

# Generate incidence sample
rng = np.random.default_rng(seed=8675309)
incidence = -np.diff(sol.ys[0], axis=0)
incidence_sample = rng.poisson(incidence)

# Sampling model
def sir(times, incidence):
    # Parameters
    initial_infections = numpyro.sample("initial_infections", dist.Exponential(1.0))
    beta = numpyro.sample("beta", dist.Exponential(1.0))
    gamma = numpyro.sample("gamma", dist.Exponential(1.0))

    initial_state = (
        population - initial_infections,  # s
        initial_infections,  # i
        0.0,  # r
    )

    # Integrate the model
    solution = des(initial_state, [beta, gamma])
    model_incidence = -jnp.diff(solution.ys[0], axis=0)

    # Observed incidence
    numpyro.sample("incidence", dist.Poisson(model_incidence), obs=incidence)

rng_key = random.PRNGKey(8811)
nuts_kernel = NUTS(sir, dense_mass=True)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(rng_key, times, incidence_sample)

patrick-kidger commented 10 months ago

The first thing that jumps out is that you don't appear to be explicitly JIT'ing your computation. Diffrax already does this for you internally for the most part, but even so, best practice is to wrap des in equinox.filter_jit.

The second is that it looks like beta and gamma might be Python floats rather than JAX arrays, in which case I suspect things are recompiling every time. Make them NumPy or JAX arrays. (When using equinox.filter_jit, the rule is that things will recompile if a JAX/NumPy array changes shape or dtype, and if anything else changes in any way at all.)

kokbent commented 10 months ago

Hi Patrick, thanks for the response. I've jitted my des function as you suggested. As for beta and gamma, making them JAX arrays in the first part of the code doesn't seem to have much effect (they are only used to generate a random sample). Within the sampling model sir(), they are handled by NumPyro, and I believe all the sampled parameters should be some form of JAX traceable array. The MCMC is still very slow. I should probably also raise the issue with NumPyro.

patrick-kidger commented 10 months ago

You can double-check whether recompilation is happening with equinox.debug.assert_max_traces, by the way.