gautierronan opened this issue 1 month ago
So `atol` and `rtol` are used for precisely one thing in an explicit solver: controlling the time stepping of the `PIDController`.
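For reference, here is a minimal sketch of the mixed absolute/relative error test that adaptive step-size controllers conventionally apply. It assumes the standard Hairer-style form (per-component scale `atol + rtol * max(|y0|, |y1|)`, RMS norm, accept if the norm is at most 1). diffrax's `PIDController` follows this general pattern, but the function names below are hypothetical and its exact internals are not reproduced here.

```python
import math


def scaled_error_norm(y0, y1, err, rtol, atol):
    """RMS norm of a local error estimate, scaled per component by atol + rtol * |y|."""
    total = 0.0
    for a, b, e in zip(y0, y1, err):
        scale = atol + rtol * max(abs(a), abs(b))
        total += (e / scale) ** 2
    return math.sqrt(total / len(err))


def accept_step(y0, y1, err, rtol, atol):
    """A step is accepted when the scaled error norm is at most 1."""
    return scaled_error_norm(y0, y1, err, rtol, atol) <= 1.0


# Same local error estimate, judged under two tolerance settings.
print(accept_step([1.0], [0.9], [5e-7], rtol=1e-6, atol=1e-6))  # accepted (norm 0.25)
print(accept_step([1.0], [0.9], [5e-7], rtol=1e-8, atol=1e-8))  # rejected (norm 25)
```

Tighter `rtol`/`atol` shrink the scale, so the same error estimate gets rejected and the controller has to take smaller (hence more) steps.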
They are used for two things in the implicit solvers: the time stepping of `PIDController`, but also the desired accuracy of the implicit solves happening in each Runge–Kutta stage.
I speculate that what you're seeing is primarily due to the second part.
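To make that second role concrete, here is a simplified scalar sketch (not diffrax's code) of a chord iteration solving a single implicit-Euler stage `z = y0 + h*f(z)`, stopping once the update falls below `atol + rtol*|z|`. diffrax's `VeryChord` uses a more elaborate, convergence-rate-based stopping test, so treat this only as an illustration of how the stage-solve tolerance changes the work done per step. The test problem and all function names are hypothetical.

```python
def chord_stage_solve(f, dfdy, y0, h, rtol, atol, max_iters=100):
    """Solve z = y0 + h*f(z) with a chord iteration.

    The derivative of the residual g(z) = z - y0 - h*f(z) is evaluated once,
    at z = y0, and reused on every iteration (chord rather than full Newton).
    Stops once |dz| <= atol + rtol*|z|. Returns (z, iterations, converged).
    """
    jac = 1.0 - h * dfdy(y0)  # frozen derivative of the residual
    z = y0
    for i in range(1, max_iters + 1):
        g = z - y0 - h * f(z)
        dz = -g / jac
        z = z + dz
        if abs(dz) <= atol + rtol * abs(z):
            return z, i, True
    return z, max_iters, False


# Hypothetical nonlinear test problem: y' = -50 y**2, one step of size 0.1 from y0 = 1.
def f(y):
    return -50.0 * y**2


def dfdy(y):
    return -100.0 * y


z_loose, it_loose, ok_loose = chord_stage_solve(f, dfdy, 1.0, 0.1, rtol=1e-3, atol=1e-3)
z_tight, it_tight, ok_tight = chord_stage_solve(f, dfdy, 1.0, 0.1, rtol=1e-10, atol=1e-10)
print(it_loose, it_tight)  # the tighter stage tolerance needs many more iterations
```

Here the exact stage value is `(-1 + sqrt(21)) / 10`; the loose solve stops near it after a handful of iterations, while the tight solve needs several times as many.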
If you want, you can explicitly control those separately via e.g. `Kvaerno4(root_finder=VeryChord(rtol=..., atol=...))`. Give that a try and see?
Cf. also lines like this for how the stepsize controller tolerances are inherited:
I figured out that it was actually because of single precision. Switching to double-precision floating point fixed the problem, and I now see that the implicit solvers are much faster than the explicit ones, as expected.
```python
import jax
import diffrax as dx

# `solve(solver, tol)` is the helper from the MWE in the original post.
tol = 1e-6

# single precision (JAX default)
print(solve(dx.Tsit5(), tol=tol).stats["num_steps"])  # 4039
print(solve(dx.Kvaerno5(), tol=tol).stats["num_steps"])  # 11902
print(solve(dx.Kvaerno3(), tol=tol).stats["num_steps"])  # 27031

# double precision
jax.config.update("jax_enable_x64", True)
print(solve(dx.Tsit5(), tol=tol).stats["num_steps"])  # 4039
print(solve(dx.Kvaerno5(), tol=tol).stats["num_steps"])  # 24
print(solve(dx.Kvaerno3(), tol=tol).stats["num_steps"])  # 90
```
I didn't expect single precision to be a limitation at these kinds of tolerances, but there we go :) Thanks for the help!
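One plausible reading of why single precision bites here (an assumption, not something confirmed in the thread): a tolerance of `1e-6` is only about eight float32 units in the last place near 1.0, so quantities the solver compares against that tolerance are not far above float32 rounding noise. The sketch below uses only the standard library; `to_float32` is a hypothetical helper that rounds a Python float to float32 and back.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Machine epsilon: spacing between 1.0 and the next representable number.
EPS32 = 2.0 ** -23  # float32
EPS64 = 2.0 ** -52  # float64

tol = 1e-6
print(f"tol / eps32 = {tol / EPS32:.1f}")  # → 8.4 float32 ulps inside the tolerance
print(f"tol / eps64 = {tol / EPS64:.3g}")  # billions of float64 ulps

# A relative change of 1e-8 is rounded away in float32 but kept in float64.
delta = 1e-8
print(to_float32(1.0 + delta) == 1.0)  # → True
print((1.0 + delta) == 1.0)            # → False
```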
Hello @patrick-kidger, I was investigating the implicit solvers of diffrax for a possible integration into dynamiqs. However, I am seeing that the performance of `Kvaerno3` and `Kvaerno5` is highly dependent on the choice of `atol` and `rtol`. Is this expected, or could this be a bug somewhere? Here is an MWE for a typical stiff ODE we want to solve:
As you can see, for `Tsit5` the number of steps (and thus the solver performance) is very stable with increasing tolerances. However, for `Kvaerno5` and `Kvaerno3`, increasing the tolerance by two orders of magnitude makes the number of steps completely explode.