patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Evaluate a batch of dense ODE solutions #373

Closed: jnibauer closed this issue 9 months ago.

jnibauer commented 9 months ago

Hey @patrick-kidger, I've been using diffrax for many reasons, one of which is its support for dense solutions. One issue I've run into is evaluating a batched dense solution. That is, if I use vmap to solve the same ODE for many different initial conditions with dense=True, does diffrax support evaluating the resulting batched dense solution?

Here's a simple example to demonstrate:

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController, DirectAdjoint, SaveAt

G = 4.3e-6  # kpc km^2 s^-2 M⊙^-1
Mdisk = 1.0e10
a = 3.0
b = 0.25
rc = .5

@jax.jit
def pot0(q,t):
    R = jnp.sqrt(jnp.sum(q[:2]**2))
    return - G*Mdisk/jnp.sqrt( R**2 + (a + jnp.sqrt(b**2 + q[-1]**2))**2 )

@jax.jit
def velocity_acceleration(t,qp,args):
    q0, p0 = qp[0:3], qp[3:6]
    acc = -jax.grad(pot0)(q0,t)
    return jnp.hstack([p0,acc])

@jax.jit
def integrator_run(qp0,t0,t1,ts,eps):
    term = ODETerm(velocity_acceleration)
    solver = Dopri5(scan_kind="bounded")
    saveat = SaveAt(t0=False, t1=True, ts=None, dense=True)
    rtol: float = 1e-7 
    atol: float = 1e-7 
    dt0=None
    stepsize_controller = PIDController(rtol=rtol, atol=atol, dtmin=0.05, force_dtmin=True)
    max_steps: int = 16**4 
    solution = diffeqsolve(
            terms=term,
            solver=solver,
            t0=t0,
            t1=t1,
            y0=qp0,
            dt0=dt0,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            discrete_terminating_event=None,
            max_steps=max_steps,
            args=eps,
            adjoint=DirectAdjoint()
        )
    return solution

# Single ODE solve
q0p0 = jnp.array([20.,3.,9,8.,-30.,4.])
sol = integrator_run(q0p0,0.0,20.0,None,0.0)

# Dense solution can be evaluated as
sol.evaluate(0.5)

# Batched ODE solve
key = jax.random.PRNGKey(84530)
q0p0_batch = jax.random.normal(key,shape=(100,6))
sol_batch = jax.vmap(integrator_run,in_axes=(0,None,None,None,None))(q0p0_batch,0.0,20.0,None,0.0)

The behavior I would like is for sol_batch.evaluate(0.5) to return an array of shape (100, 6): the interpolated solution of the initial value problem at t=0.5 for each of the 100 initial conditions. Instead, I get the broadcasting error TypeError: lt got incompatible shapes for broadcasting: (100,), (65537,). This is likely the intended behavior, but I wanted to check whether there is another way to evaluate a batched dense solution in this case. Thanks!
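The error is consistent with jax.vmap having stacked every array inside the returned solution along a new leading axis of size 100, so evaluate, which expects an unbatched solution, ends up comparing arrays of mismatched shapes. (65537 is 16**4 + 1, i.e. max_steps + 1, which suggests it is the length of the stored step-time buffer.) The mechanism can be illustrated without diffrax using a hypothetical ToyDense class: a NamedTuple that linearly interpolates a fixed grid with jnp.interp. It is not diffrax's Solution, only a stand-in with the same batching behavior.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyDense(NamedTuple):
    """Hypothetical stand-in for a dense solution (not diffrax's Solution class).

    Stores a grid of times and values and interpolates linearly between them.
    """

    ts: jax.Array  # shape (n,) for a single solve
    ys: jax.Array  # shape (n,) for a single solve

    def evaluate(self, t):
        # jnp.interp expects 1-D ts and ys, i.e. an *unbatched* solution.
        return jnp.interp(t, self.ts, self.ys)


def make_dense(y0):
    ts = jnp.linspace(0.0, 1.0, 11)
    ys = y0 * (1.0 - ts)  # exactly linear in t, so interpolation is exact
    return ToyDense(ts=ts, ys=ys)


y0_batch = jnp.array([1.0, 2.0, 3.0])
dense_batch = jax.vmap(make_dense)(y0_batch)

# vmap gives every array leaf a new leading batch axis, including ts,
# even though ts does not depend on y0.
print(dense_batch.ts.shape)  # (3, 11)

# Calling dense_batch.evaluate(0.5) directly would pass 2-D arrays to
# jnp.interp. Mapping over the leading axis gives each call an unbatched view.
values = jax.vmap(lambda d: d.evaluate(0.5))(dense_batch)
print(values)  # approximately [0.5, 1.0, 1.5]
```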

patrick-kidger commented 9 months ago

Yup, you absolutely can. The trick is to call .evaluate either (a) whilst still inside the jax.vmap, or (b) by passing sol_batch back through jax.vmap: jax.vmap(lambda s: s.evaluate(0.5))(sol_batch).

I hope that helps!

jnibauer commented 9 months ago

Perfect! Thanks!