patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Different step sizes within batch? #500

Closed: jucor closed this issue 2 months ago

jucor commented 3 months ago

Dear @patrick-kidger

Is there any way that diffrax would support using different timesteps within a batch, like https://github.com/martenlienen/torchode offers?

For context: I need to evaluate a 1-dimensional ODE on a fixed time grid, but with 100,000 different parameter values, as part of a Sequential Monte Carlo sampler targeting an ODE-based likelihood.

The best way I've found to parallelize these ODE solves with diffrax is to turn them into a single 100,000-dimensional ODE, where each dimension y[i] corresponds to one parameter value and dydt[i] depends only on y[i] (if the system were linear, its matrix would be diagonal, but it's nonlinear :) ).

But of course some parameters make the ODE much harder than others (sometimes even producing NaNs), and because the adaptive step size is shared across all dimensions, those hard cases shrink the steps for every dimension, making the whole solve slower and less stable.

Maybe I could just use https://github.com/martenlienen/torchode , which seems to offer exactly this, but it lives in the PyTorch ecosystem and my sampler is in JAX. Maybe I could mix and match, though?

Thanks for letting me know whether heterogeneous timesteps within a batch are possible, and if not, whether they're on the roadmap at all!

lockwo commented 2 months ago

What prevents you from vmapping over these parameters? For instance, you can replicate the example from the torchode README with (exactly) this:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, PIDController


def f(t, y, args):
    # Simple linear decay; every batch element solves the same ODE.
    return -0.5 * y


# Two initial conditions (a batch of 2), each with shape (1,).
y0 = jnp.array([[1.2], [5.0]])
n_steps = 10
# A different evaluation grid per batch element: [0, 5] and [3, 4].
t_eval = jnp.stack((jnp.linspace(0, 5, n_steps), jnp.linspace(3, 4, n_steps)))

term = ODETerm(f)
solver = Dopri5()
stepsize_controller = PIDController(atol=1e-6, rtol=1e-3)


def solve(y0, ts):
    # Solve a single (unbatched) trajectory, starting at its own ts[0].
    saveat = SaveAt(ts=ts)
    return diffeqsolve(
        term,
        solver,
        t0=ts[0],
        t1=5,
        dt0=0.1,
        y0=y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
    )


# vmap batches the solve; each element adapts its own step size.
sol = jax.vmap(solve)(y0, t_eval)

print(sol.stats["num_steps"])  # steps taken per trajectory

plt.title("Diffrax")
plt.plot(sol.ts[0], sol.ys[0], label="Solution 1")
plt.plot(sol.ts[1], sol.ys[1], label="Solution 2")
plt.legend()
plt.show()
```

Here each trajectory has its own ts and takes its own number of steps (the print outputs [4 3]), and the result appears to solve the same set of equations as the torchode example.

[Two screenshots attached, dated 2024-08-24]
patrick-kidger commented 2 months ago

Indeed Diffrax supports this already via vmap :)

jucor commented 2 months ago

Wowser, this is terrific. Can't believe how simple it is, yet it works :) Thanks a lot!