patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Making a simulation simultaneously on 2 different solvers #447

Open etienney opened 2 months ago

etienney commented 2 months ago

Hi,

First, thanks to those developing this library, which is useful to my work. I want to compute an ODETerm running on two different solvers simultaneously. Is that possible?

Let me explain. I have a class State(eqx.Module) with fields y: Array and error: Array, where I would like to compute y with a Tsit5 solver and error with an Euler one, for instance. (I don't need high precision on the error, and there are structure-preservation considerations, so running it with Euler while running y with Tsit5 would make sense.)

Why not run them separately? Because the error is intended to be an a posteriori estimate of the error made by truncating y in space (y theoretically lives in an infinite-dimensional space, but has to be truncated to a finite-dimensional one to be computed, which induces an approximation). If 'error' estimates that the truncation is poor at some time t of the simulation, the solve exits diffrax via a DiscreteTerminatingEvent, widens the truncation for y, and returns to diffrax at this t. I have $\frac{d(error)}{dt}=f(y)$, with f being a nontrivial function. Thus I cannot run these things separately.

Is there a way to do such a thing?

I hope I'm clear enough. Don't hesitate to tell me if that's not the case.

lockwo commented 2 months ago

My go-to approach (not sure if it is the simplest) would be to make a custom solver. If you define a custom step function that returns a y update and uses a different method for computing the error, that should work (in code; I'm not sure about the theory). If you can discard the Tsit5 error, you could just make it a wrapper: define a step function that calls Tsit5's step, discards its error, computes a new one, and returns that. If you can't or don't want the Tsit5 error to be computed at all, the custom solver approach should still work; you might just have to write or copy more step code.
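The idea of a single step function that treats the two components differently can be sketched outside diffrax in plain Python. This is only an illustration under stated assumptions, not diffrax's solver API: classical RK4 stands in for Tsit5, and the names `rk4_step`, `mixed_step`, `vf_y` and `vf_err` are hypothetical.

```python
from typing import Callable, Tuple

VectorField = Callable[[float, float], float]  # (t, y) -> time derivative
State = Tuple[float, float]                    # (y, error)


def rk4_step(f: VectorField, t: float, y: float, dt: float) -> float:
    """One classical RK4 step (a stand-in for Tsit5 in this sketch)."""
    k1 = f(t, y)
    k2 = f(t + 0.5 * dt, y + 0.5 * dt * k1)
    k3 = f(t + 0.5 * dt, y + 0.5 * dt * k2)
    k4 = f(t + dt, y + dt * k3)
    return y + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0


def mixed_step(vf_y: VectorField, vf_err: VectorField,
               t: float, state: State, dt: float) -> State:
    """Advance y with the high-order method and error with forward Euler."""
    y, err = state
    y_new = rk4_step(vf_y, t, y, dt)
    # d(error)/dt = vf_err(t, y), evaluated once at the start of the step.
    err_new = err + dt * vf_err(t, y)
    return y_new, err_new
```

The Euler update costs one extra evaluation of `vf_err` per step, whereas putting `error` inside the high-order method would evaluate it at every stage.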

patrick-kidger commented 2 months ago

Agreed, this sounds like a use case for writing a custom solver. I'd recommend taking a look at the source code for a simple solver like Euler for an example of what that looks like.

You can of course arrange for your custom solver to just wrap an existing one (like Tsit5):

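import diffrax

class MySolver(diffrax.AbstractSolver):
    solver: diffrax.AbstractSolver  # the wrapped solver, e.g. Tsit5

    def step(...):
        # delegate to the wrapped solver, then adjust its outputs as needed
        ... = self.solver.step(...)

solver = MySolver(diffrax.Tsit5())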

I'm not sure why you want to use Euler instead of Tsit5's embedded error estimate, though. Using the latter will actually be cheaper! All of the required function evaluations have already been made.

etienney commented 2 months ago

Thank you both for your answers!

> I'm not sure why you want to use Euler instead of Tsit5's embedded error estimate, though. Using the latter will actually be cheaper! All of the required function evaluations have already been made.

I'm not sure we are talking about the same thing. It looks like you are talking about the error estimate for the time dimension, whereas I am talking about an error estimate for the space dimensions (i.e. the dimensions of the matrix). I do know that estimating the error for the time step size is almost free with the Tsit5 method. But I don't think it gives me information on the error made by truncating the space. Am I missing something?

Consequently, I believe that wrapping a solver and rewriting a step function would only let me look at the error relative to time?
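To make the distinction concrete: an embedded error estimate compares two time-stepping formulas built from the same stage evaluations, so it measures time-discretization error only. Below is a minimal sketch using the simple Heun-Euler pair rather than Tsit5's actual coefficients; the function name `heun_euler_step` is hypothetical.

```python
from typing import Callable, Tuple


def heun_euler_step(f: Callable[[float, float], float],
                    t: float, y: float, dt: float) -> Tuple[float, float]:
    """One step of the Heun-Euler embedded pair.

    Returns the second-order solution and a local error estimate in time.
    """
    k1 = f(t, y)
    k2 = f(t + dt, y + dt * k1)
    y_low = y + dt * k1                # forward Euler, order 1
    y_high = y + 0.5 * dt * (k1 + k2)  # Heun, order 2, reuses k1 and k2
    # The estimate needs no extra evaluations of f: both solutions share stages.
    return y_high, abs(y_high - y_low)
```

Nothing in this estimate depends on how y was truncated in space, which is why a separate quantity has to be tracked for that purpose.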

patrick-kidger commented 2 months ago

Ah, sorry -- I didn't understand that detail.

So I think you could still make this happen with a custom solver. A solver gets to return a result specifying whether things are proceeding acceptably. I think you should be able to have your solver output diffrax.RESULTS.event_occurred, which will cause the solve to halt at that step.

etienney commented 2 months ago

Yes indeed! I can find out whether my condition is triggered (when error > a certain tolerance) via an event: the solver stops there, and I can tell whether it stopped because of my event via solution.event_mask, for instance. And that's exactly what I do.

However, to check my condition I have to check this error, which I need to integrate along the trajectory since its form is $|error(t)|_1 \leq |error(0)|_1 + \int_0^t |f_s(y(s))|_1 \, ds$. To do this I have a custom ODETerm which computes both $\dot{y}$ and $\dot{error}$ through a Tsit5 solver; my condition then checks state.sol.error against the tolerance and triggers the event, etc.

The problem is that this applies Tsit5 to compute BOTH y and error. I would like to apply Tsit5 to y and Euler to error (for computational efficiency and structure-preservation considerations).
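The overall workflow described here (integrate, halt when the accumulated error bound exceeds the tolerance, widen the truncation, resume from the same t) can be sketched in plain Python. Everything in this toy model is an assumption made for illustration and none of it is diffrax code: y is a scalar obeying dy/dt = -y, the error rate is |y| / n**2 for a hypothetical truncation size n, RK4 stands in for Tsit5, and the accumulated error is reset to zero after each widening.

```python
from typing import Callable, Tuple

State = Tuple[float, float]  # (y, error)
Step = Callable[[float, State, float], State]


def make_step(n: int) -> Step:
    """Build a step function for truncation size n (toy model, hypothetical)."""
    def f(y: float) -> float:
        return -y

    def step(t: float, state: State, dt: float) -> State:
        y, err = state
        # y: classical RK4 (stand-in for Tsit5).
        k1 = f(y)
        k2 = f(y + 0.5 * dt * k1)
        k3 = f(y + 0.5 * dt * k2)
        k4 = f(y + dt * k3)
        y_new = y + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
        # error: forward Euler, rate evaluated at the start of the step.
        err_new = err + dt * abs(y) / n**2
        return y_new, err_new

    return step


def integrate_until_event(step: Step, t: float, state: State,
                          t_end: float, dt: float, tol: float):
    """Step until t_end, or stop early once the error exceeds tol."""
    while t < t_end - 1e-12:
        h = min(dt, t_end - t)
        state = step(t, state, h)
        t += h
        if state[1] > tol:
            return t, state, True   # event: truncation judged too coarse
    return t, state, False


def solve_with_restarts(t_end: float, dt: float, tol: float, n: int = 1):
    """Outer loop: on each event, widen the truncation and resume at the same t."""
    t, state, restarts = 0.0, (1.0, 0.0), 0
    while True:
        t, state, event = integrate_until_event(make_step(n), t, state, t_end, dt, tol)
        if not event:
            return t, state, n, restarts
        n *= 2                     # widen the truncation (assumed policy)
        state = (state[0], 0.0)    # assumed: error budget restarts after widening
        restarts += 1
```

In a real problem, widening the truncation would also change the shape of y, so the state would need to be padded or re-projected before resuming.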