jucor closed this issue 2 months ago.
What prevents you from vmapping over these parameters? That is to say, you can replicate the torchode example in the README with something (exactly) like this:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt, PIDController

def f(t, y, args):
    return -0.5 * y

# Two initial conditions, one per trajectory (batch axis first).
y0 = jnp.array([[1.2], [5.0]])
n_steps = 10
# A different output grid for each trajectory.
t_eval = jnp.stack((jnp.linspace(0, 5, n_steps), jnp.linspace(3, 4, n_steps)))

term = ODETerm(f)
solver = Dopri5()
stepsize_controller = PIDController(atol=1e-6, rtol=1e-3)

def solve(y0, ts):
    # Each call integrates from its own ts[0] to t1=5 and saves at its own ts.
    saveat = SaveAt(ts=ts)
    return diffeqsolve(term, solver, t0=ts[0], t1=5, dt0=0.1, y0=y0,
                       saveat=saveat, stepsize_controller=stepsize_controller)

# vmap batches the whole solve; each trajectory keeps its own adaptive steps.
sol = jax.vmap(solve)(y0, t_eval)
print(sol.stats["num_steps"])

plt.title("Diffrax")
plt.plot(sol.ts[0], sol.ys[0], label="Solution 1")
plt.plot(sol.ts[1], sol.ys[1], label="Solution 2")
plt.legend()
plt.show()
Here each trajectory has its own ts and takes a different number of steps (the print outputs [4 3]). And they seem to solve the same set of equations.
Indeed, Diffrax already supports this via vmap :)
Wowser, this is terrific. Can't believe how simple it is, yet it works :) Thanks a lot!
Dear @patrick-kidger
Is there any way Diffrax could support different time steps within a batch, as https://github.com/martenlienen/torchode offers?
For context: I need to evaluate a 1-dimensional ODE on a fixed time grid but with 1000,000 different parameters, while using a Sequential Monte Carlo sampler on an ODE-based likelihood.
The best way I've found with Diffrax to parallelize those calls to the ODE is to make this a single 100,000-dimensional ODE, where each dimension y[i] corresponds to one value of the parameters, and for each dimension i, dydt[i] only relies on y[i] (if it were linear we'd say it's a diagonal matrix, but it's nonlinear :) ). But of course some parameters make for a more difficult ODE than others (sometimes even producing NaNs), and this impacts the step size for all dimensions, making the solve slower and less stable.
Maybe I could just use https://github.com/martenlienen/torchode, which seems to promise exactly this, but it's in the PyTorch ecosystem and my sampler is in JAX. Maybe I could mix and match, though?
Thanks for letting me know whether this heterogeneous-timesteps-within-a-batch approach is possible, and if not, whether it's on the roadmap at all!