patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Possible performance issue with implicit solvers #428

Open gautierronan opened 1 month ago

gautierronan commented 1 month ago

Hello @patrick-kidger, I was investigating the implicit solvers of diffrax for a possible integration into dynamiqs. However, I am seeing that the performance of Kvaerno3 and Kvaerno5 is highly dependent on the choice of atol and rtol. Is this expected? Or could this be a bug somewhere?

Here is a MWE, for a typical stiff ODE we want to solve:

import diffrax as dx
import jax.numpy as jnp

# simulation parameters
N = 16
alpha = 2.0
T = 1.0

# quantum operators
a = jnp.diag(jnp.sqrt(jnp.arange(1, N)), 1)
i = jnp.eye(N)
L = jnp.linalg.matrix_power(a, 4) - alpha**4 * i
Lt = L.T
LtL = Lt @ L

# vector field
def vector_field(t, y, _):
    return L @ y @ Lt - 0.5 * (LtL @ y + y @ LtL)

# initial state
y0 = jnp.zeros((N, N))
y0 = y0.at[0, 0].set(1.0)

# define solver
tsave = jnp.linspace(0.0, T, 100)
solve = lambda solver, tol: dx.diffeqsolve(
    dx.ODETerm(vector_field),
    solver,
    t0=0.0,
    t1=T,
    dt0=0.01,
    y0=y0,
    saveat=dx.SaveAt(ts=tsave),
    stepsize_controller=dx.PIDController(rtol=tol, atol=tol),
    max_steps=100_000,
    progress_meter=dx.TqdmProgressMeter(),
)

# test several explicit and implicit solvers with different tolerances
solver = dx.Tsit5()
num_steps = solve(solver, tol=1e-4).stats["num_steps"] # 4037
num_steps = solve(solver, tol=1e-5).stats["num_steps"] # 4039
num_steps = solve(solver, tol=1e-6).stats["num_steps"] # 4039

solver = dx.Kvaerno5()
num_steps = solve(solver, tol=1e-4).stats["num_steps"] # 12
num_steps = solve(solver, tol=1e-5).stats["num_steps"] # 3788
num_steps = solve(solver, tol=1e-6).stats["num_steps"] # 11902

solver = dx.Kvaerno3()
num_steps = solve(solver, tol=1e-4).stats["num_steps"] # 25
num_steps = solve(solver, tol=1e-5).stats["num_steps"] # 561
num_steps = solve(solver, tol=1e-6).stats["num_steps"] # 27031

As you can see, for Tsit5 the number of steps (and thus the solver performance) is very stable as the tolerance is tightened. However, for Kvaerno5 and Kvaerno3, tightening the tolerance by two orders of magnitude makes the number of steps explode (from 12 to 11902 and from 25 to 27031, respectively).

patrick-kidger commented 1 month ago

So atol and rtol are used for precisely one thing in an explicit solver: controlling the time stepping of the PIDController.
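To make the role of the tolerances in step-size control concrete, here is a minimal NumPy sketch of the textbook accept/reject test used by adaptive solvers. It is an illustration under assumptions, not diffrax's actual code: the function name `scaled_error`, the RMS norm, and the sample numbers are all hypothetical.

```python
import numpy as np

def scaled_error(y_old, y_new, y_err, rtol, atol):
    """RMS norm of the local error estimate, scaled so that <= 1 means 'accept'.

    Textbook form; assumed (not verified here) to resemble what PIDController does.
    """
    scale = atol + rtol * np.maximum(np.abs(y_old), np.abs(y_new))
    return float(np.sqrt(np.mean((y_err / scale) ** 2)))

# Hypothetical state before/after a step, and a hypothetical error estimate.
y_old = np.array([1.0, 0.5])
y_new = np.array([0.9, 0.6])
y_err = np.array([5e-6, 4e-6])

loose = scaled_error(y_old, y_new, y_err, rtol=1e-4, atol=1e-4)
tight = scaled_error(y_old, y_new, y_err, rtol=1e-6, atol=1e-6)

print("accepted at tol=1e-4:", loose <= 1.0)
print("accepted at tol=1e-6:", tight <= 1.0)
```

The same error estimate passes at the loose tolerance and fails at the tight one, in which case the controller would retry with a smaller step.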

They are used for two things in the implicit solvers: the time stepping of PIDController, but also the desired accuracy of the implicit solves happening in each Runge--Kutta stage.

I speculate that what you're seeing is primarily due to the second part.

If you want, you can explicitly control those separately via e.g. Kvaerno4(root_finder=VeryChord(rtol=..., atol=...)). Give that a try and see?

Cf. also lines like this for how the stepsize controller tolerances are inherited:

https://github.com/patrick-kidger/diffrax/blob/0a59c9dbd34f580efb3505386f38ce9fcedb120b/diffrax/_solver/kvaerno4.py#L114

gautierronan commented 1 month ago

I figured out it was actually because of single precision. Switching to double-precision floating point fixed the problem, and I now get that the implicit solvers are much faster than the explicit ones, as expected.

tol = 1e-6

# single-precision
print(solve(dx.Tsit5(), tol=tol).stats["num_steps"]) # 4039
print(solve(dx.Kvaerno5(), tol=tol).stats["num_steps"]) # 11902
print(solve(dx.Kvaerno3(), tol=tol).stats["num_steps"]) # 27031

# double-precision
import jax
jax.config.update("jax_enable_x64", True)
print(solve(dx.Tsit5(), tol=tol).stats["num_steps"]) # 4039
print(solve(dx.Kvaerno5(), tol=tol).stats["num_steps"]) # 24
print(solve(dx.Kvaerno3(), tol=tol).stats["num_steps"]) # 90
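One caveat when reproducing this comparison: arrays that already exist keep their dtype when the flag is flipped, so the double-precision numbers presumably come from rebuilding the operators and `y0` after enabling it. The sketch below is an assumption about the setup, not something stated in the thread, and shows the flag being set before any array is created.

```python
import jax

# Enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Same initial state as in the MWE, now created in double precision.
N = 16
y0 = jnp.zeros((N, N)).at[0, 0].set(1.0)
print(y0.dtype)
```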

I didn't expect single precision to limit solving at these kinds of tolerances, but there we go :) Thanks for the help!
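A rough way to see why single precision can be limiting here, offered as a plausible reading rather than a diagnosis confirmed in the thread: a tolerance of 1e-6 is only about 8 times the float32 machine epsilon, so error estimates and implicit-solve residuals may be dominated by rounding noise. The NumPy sketch below just compares the tolerance with the machine epsilon of each dtype.

```python
import numpy as np

tol = 1e-6

# Machine epsilon: spacing between 1.0 and the next representable number.
eps32 = float(np.finfo(np.float32).eps)
eps64 = float(np.finfo(np.float64).eps)

# How many units of rounding error fit inside the requested tolerance.
ratio32 = tol / eps32
ratio64 = tol / eps64

print(f"float32: eps = {eps32:.3e}, tol / eps = {ratio32:.3g}")
print(f"float64: eps = {eps64:.3e}, tol / eps = {ratio64:.3g}")
```

In double precision the same tolerance sits more than nine orders of magnitude above the rounding level, which is consistent with the step counts reported above.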