patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

Accelerate ODE solver [What did I miss?] #466

Open zhengqigao opened 1 month ago

zhengqigao commented 1 month ago

Hi,

I am playing around with diffrax's ODE-solving functionality. In a nutshell, I define a simple feedforward MLP with random initialization and benchmark the runtime of using it as the temporal derivative of an ODE. With the following code, I measured the runtime of the ODE solve at around 3.7 seconds, which seems much slower than other ODE solver frameworks.

I am new to JAX and diffrax. What did I miss in my implementation?

import equinox as eqx
import jax
import diffrax
import jax.numpy as jnp
import time

class MLPeqx(eqx.Module):
    layers: list
    activation: callable = eqx.static_field()

    def __init__(self, hidden_dims):
        super().__init__()
        tmp_key = jax.random.split(jax.random.PRNGKey(0), len(hidden_dims) - 1)
        self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in
                       range(len(hidden_dims) - 1)]
        self.activation = jax.nn.relu

    def __call__(self, x):
        for i in range(len(self.layers) - 1):
            x = self.activation(self.layers[i](x))
        x = self.layers[-1](x)
        return x

class ODEjax(eqx.Module):
    func: MLPeqx

    def __init__(self, hidden_dims):
        super().__init__()
        self.func = MLPeqx(hidden_dims)

    def __call__(self, t, y, args=None):
        return self.func(y)

def solve_ode(input_x, t, func, cfg):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(func),
        cfg['method'],
        t0=t[0],
        t1=t[-1],
        y0=input_x,
        dt0=None,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(atol=cfg['atol'], rtol=cfg['rtol']),
    )
    return sol.ys

def run_diffrax(hidden_dims, input_x, t, num_t, cfg):
    t = jnp.linspace(t[0], t[1], num_t)
    func = ODEjax(hidden_dims)
    y = jax.vmap(solve_ode, in_axes=(0, None, None, None))(input_x, t, func, cfg)
    return y

if __name__ == '__main__':
    batch_size = 128
    hidden_dims = [100, 100, 100]
    input_x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 100))

    start_time = time.time()
    run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
        'method': diffrax.Dopri5(),
        'atol': 1e-5,
        'rtol': 1e-5})
    end_time = time.time()

    print(f"run time = {end_time - start_time:.3f} (sec)")
lockwo commented 1 month ago

I recommend checking out JAX's docs on benchmarking (https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code); the tl;dr for this example is:

  1. JAX compile times make the first call slower (and compilation is generally excluded from benchmarks)
  2. with async dispatch you need to call block_until_ready() before stopping the timer

With the following code I got: 19.4 ms ± 4.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

import equinox as eqx
import jax
import diffrax
import jax.numpy as jnp
import time

class MLPeqx(eqx.Module):
    layers: list
    activation: callable = eqx.static_field()

    def __init__(self, hidden_dims):
        super().__init__()
        tmp_key = jax.random.split(jax.random.PRNGKey(0), len(hidden_dims) - 1)
        self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in
                       range(len(hidden_dims) - 1)]
        self.activation = jax.nn.relu

    def __call__(self, x):
        for i in range(len(self.layers) - 1):
            x = self.activation(self.layers[i](x))
        x = self.layers[-1](x)
        return x

class ODEjax(eqx.Module):
    func: MLPeqx

    def __init__(self, hidden_dims):
        super().__init__()
        self.func = MLPeqx(hidden_dims)

    def __call__(self, t, y, args=None):
        return self.func(y)

def solve_ode(input_x, t, func, cfg):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(func),
        cfg['method'],
        t0=t[0],
        t1=t[-1],
        y0=input_x,
        dt0=None,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(atol=cfg['atol'], rtol=cfg['rtol']),
    )
    return sol.ys

@eqx.filter_jit
def run_diffrax(hidden_dims, input_x, t, num_t, cfg):
    t = jnp.linspace(t[0], t[1], num_t)
    func = ODEjax(hidden_dims)
    y = jax.vmap(solve_ode, in_axes=(0, None, None, None))(input_x, t, func, cfg)
    return y

batch_size = 128
hidden_dims = [100, 100, 100]
input_x = jax.random.normal(jax.random.PRNGKey(0), (batch_size, 100))

_ = run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()

%%timeit
_ = run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()
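
Note that %%timeit is an IPython magic. As a minimal sketch for a plain script (assuming the same run_diffrax and inputs defined above), the equivalent measurement with time.perf_counter would look like:

import time

cfg = {'method': diffrax.Dopri5(), 'atol': 1e-5, 'rtol': 1e-5}

# Warm-up call: triggers compilation, which is excluded from the measurement.
_ = run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, cfg).block_until_ready()

n = 10
start = time.perf_counter()
for _ in range(n):
    # block_until_ready() forces async dispatch to finish before we stop the clock.
    _ = run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, cfg).block_until_ready()
elapsed = time.perf_counter() - start
print(f"mean per call: {elapsed / n * 1e3:.1f} ms")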
zhengqigao commented 1 month ago

Thanks so much! I tried this on my end and observed similar runtimes. I have a follow-up question. Say I first run with atol=rtol=1e-5, and later in my code I run with atol=rtol=1e-4. I observe that run_diffrax runs slowly again when changing from 1e-5 to 1e-4, presumably because of recompilation. Namely,

# first time of atol=rtol=1e-5, takes ~2secs
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()

# second time of atol=rtol=1e-5, takes ~0.008secs
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-5,
    'rtol': 1e-5}).block_until_ready()

# first time of atol=rtol=1e-4, takes ~2secs
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-4,
    'rtol': 1e-4}).block_until_ready()

# second time of atol=rtol=1e-4, takes ~0.008secs
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': 1e-4,
    'rtol': 1e-4}).block_until_ready()

Is this behavior expected? Is there a way to compile only once for arbitrary atol/rtol values, so that every call runs at the millisecond level regardless of atol and rtol?

lockwo commented 1 month ago

Yes, this behavior is expected. The Python floats are marked as static by the filtering that happens before jit, so each new value triggers a recompile. You can make them non-static by passing them as JAX types (e.g. arrays).

start_time = time.time()
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': jnp.array(1e-5),
    'rtol': jnp.array(1e-5)}).block_until_ready()
end_time = time.time()
print(f"run time = {end_time - start_time:.3f} (sec)")

start_time = time.time()
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': jnp.array(1e-5),
    'rtol': jnp.array(1e-5)}).block_until_ready()
end_time = time.time()
print(f"run time = {end_time - start_time:.3f} (sec)")

start_time = time.time()
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': jnp.array(1e-4),
    'rtol': jnp.array(1e-4)}).block_until_ready()
end_time = time.time()
print(f"run time = {end_time - start_time:.3f} (sec)")

start_time = time.time()
run_diffrax(hidden_dims, input_x, [0.0, 1.0], 100, {
    'method': diffrax.Dopri5(),
    'atol': jnp.array(1e-4),
    'rtol': jnp.array(1e-4)}).block_until_ready()
end_time = time.time()
print(f"run time = {end_time - start_time:.3f} (sec)")
This prints:

run time = 4.057 (sec)
run time = 0.016 (sec)
run time = 0.013 (sec)
run time = 0.013 (sec)
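
To make the static-vs-traced distinction concrete, here is a minimal sketch (illustrative, not from diffrax itself; the counter is just to observe retracing) showing that eqx.filter_jit retraces for each new Python float but reuses a single compilation across array-valued arguments:

import equinox as eqx
import jax.numpy as jnp

trace_count = 0

@eqx.filter_jit
def f(x, tol):
    global trace_count
    trace_count += 1  # runs only when the function is (re)traced, not on every call
    return jnp.where(jnp.abs(x) > tol, x, 0.0)

x = jnp.ones(3)
f(x, 1e-5)
f(x, 1e-4)
print(trace_count)  # 2: each distinct Python float is static, so it retraces

f(x, jnp.array(1e-5))
f(x, jnp.array(1e-4))
print(trace_count)  # 3: arrays are traced values, so one trace covers both

On a cache hit the counter does not move, which is why the array-valued tolerances above keep the ~0.013 sec runtime after the first compile.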