patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Diffrax activating the event on a rejected step #464

Open etienney opened 1 month ago

etienney commented 1 month ago

Hi

I'm using diffrax 0.5.1 with events like this:

import diffrax as dx

def condition(state, **kwargs):
    "some condition"
    ...
    return cdn

event = dx.DiscreteTerminatingEvent(cond_fn=condition)
solution = dx.diffeqsolve(
    ...
    discrete_terminating_event=event
)

and I noticed that diffrax activates the terminating event even if the step has been rejected. I suppose this is undesirable behavior? Also, as far as I can tell, there is no way to find out inside the event, through "state" or "kwargs", whether the step was rejected.

I'm sorry that I can't give a minimal working example here: I don't know how to reproduce it by making an event trigger at the "right moment", that is, just when a time step has been rejected (and the code I'm using is too large to count as a working example). I'm fairly sure it happens as described, though, since solution.stats gives me 'num_accepted_steps': Array(0, dtype=int32), 'num_rejected_steps': Array(1, dtype=int32) and my event is triggered.

patrick-kidger commented 1 month ago

Hmmm! That's not good!

@cholberg - WDYT?

lockwo commented 1 month ago

0.5.1 is pre the events refactor right?

etienney commented 1 month ago

0.5.1 is pre the events refactor right?

Yes, it is! (I'm still using this version for one main reason: state.tnext is not defined in 0.6.0 alongside state.tprev, and I need dt = state.tnext - state.tprev for my condition. But that's another story.) That said, the code for 0.6.0 assumes backward compatibility with TerminatingEvent according to this note, so I suppose the problem is still relevant for now.

cholberg commented 1 month ago

As far as I can tell, this happens exactly because your condition depends on tnext. This is not possible right now in 0.6.0 (at least not in a straightforward way), as you mention yourself. But I think there should be a way to do what you want in the newest version that perhaps would not run into the same problems. The idea I had in mind would be to add an extra variable to the state of your differential equation that keeps track of the time. So the vector field for your ODETerm would look something like this:

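def vf(t, y, args):
    # y is the augmented state: (tracked time, original state x).
    t_tracked, x = y
    x_out = ...  # dynamics of x
    t_out = 1  # d(t_tracked)/dt = 1, so t_tracked follows the solver time
    return (t_out, x_out)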

@patrick-kidger: On a different note, this made me notice that the current implementation is actually not quite backwards compatible, since DiscreteTerminatingEvent takes state as the first positional argument, which is not a possible argument in the newest version. I am not sure if this is something we would want to change? (Then, of course, you would not have to use the workaround described above if you wanted your event condition to depend on dt.)

patrick-kidger commented 1 month ago

@etienney: So tnext was removed deliberately. It's not clearly defined in general: for example when doing a root-find over the length of the final step, to determine the exact time at which a condition is triggered.

What's your use case for needing the length of a step?


@cholberg: indeed, removing this was deliberate (and despite our backward compatibility layer, the break here is the reason I bumped the minor version to 0.6.0, rather than continue down 0.5.x). The state object we were passing is an internal abstraction that isn't stable from release to release -- at times we may add or remove fields. It's also not clearly defined when doing a root find to find an event.

If there's any additional information we could reasonably be forwarding to the event, that we aren't at the moment, then I think that should become an additional **kwarg?

WDYT?

etienney commented 1 month ago

As far as I can tell, this happens exactly because your condition depends on tnext. This is not possible right now in 0.6.0 (at least not in a straightforward way), as you mention yourself. But I think there should be a way to do what you want in the newest version that perhaps would not run into the same problems. The idea I had in mind would be to add an extra variable to the state of your differential equation that keeps track of the time. So the vector field for your ODETerm would look something like this:

def vf(t, y, args):
    t, x = y
    x_out = ...
    t_out = 1
    return (t_out, x_out)

Indeed, I considered this possibility, though I fear that carrying two variables would make the computation of vf more expensive, and I'm working on a numerical solver where execution speed is key. So this is maybe not the best solution for me.

@etienney: So tnext was removed deliberately. It's not clearly defined in general: for example when doing a root-find over the length of the final step, to determine the exact time at which a condition is triggered.

What's your use case for needing the length of a step?

My use case is the one from issue 462, where I posted an MWE (2nd message). Let's say aux is the derivative of a function that I need to compute in that form (for structure-preservation reasons). For my event I need the integral of that function rather than its derivative, so I need a dt.

patrick-kidger commented 1 month ago

I think if you want an integral then the best way to do this is to take advantage of the fact that diffeqsolve is itself computing the integral y(T) = y(0) + \int_0^T vector_field(t, y(t), args) dt! So augment your vector field with the quantity you would like to compute, and then pipe that value into your event function.

In particular you shouldn't need access to dt or to compute any kind of integral yourself.
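A minimal sketch of this augmentation idea, written in plain Python rather than diffrax so that it stays self-contained. The state is a pair (y, aux) with aux' equal to the quantity whose integral the event needs. The fixed-step RK4 integrator, the toy dynamics y' = -y, the integrand y, and the threshold 0.5 are all assumptions chosen for illustration, not anything taken from the issue.

```python
def vector_field(t, state):
    # state = (y, aux): y follows the toy ODE y' = -y, and aux' = y,
    # so aux(t) accumulates the integral of y from 0 to t.
    y, aux = state
    return (-y, y)

def rk4_step(f, t, state, dt):
    # One classical fourth-order Runge-Kutta step on a tuple-valued state.
    def shifted(k, scale):
        return tuple(s + scale * ki for s, ki in zip(state, k))
    k1 = f(t, state)
    k2 = f(t + dt / 2, shifted(k1, dt / 2))
    k3 = f(t + dt / 2, shifted(k2, dt / 2))
    k4 = f(t + dt, shifted(k3, dt))
    return tuple(
        s + dt / 6 * (a + 2 * b + 2 * c + d)
        for s, a, b, c, d in zip(state, k1, k2, k3, k4)
    )

def condition(t, state):
    # The event reads the accumulated integral from the state; no dt is needed.
    _, aux = state
    return aux >= 0.5

def solve_until_event(t0, t1, dt, state):
    # Step until the final time is reached or the event condition holds.
    t = t0
    while t < t1 and not condition(t, state):
        state = rk4_step(vector_field, t, state, dt)
        t += dt
    return t, state

t_event, (y_end, aux_end) = solve_until_event(0.0, 5.0, 0.01, (1.0, 0.0))
```

For this toy problem aux(t) = 1 - exp(-t), so the loop should stop within one step of t = ln 2. The event function only ever looks at the state, which is the point of the suggestion.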

etienney commented 1 month ago

Yes, this is indeed the obvious and perfect solution for me! But as I said, it is not possible. The "real" y (not aux) I'm integrating should be integrated with a Tsit5 Runge-Kutta method. Those methods rest on some hypotheses, one of which is that the solution should be at least C^k (i.e. its k-th derivative should be continuous), with k the order of the RK scheme (k=4 for Tsit5, as it's an RK4 checking its solution with an RK5, if I understood well). The derivative of the aux I'm integrating is the absolute value of something, which does not have this property. As such the solver goes crazy. (In practice, it can go crazy to the point of outputting negative values for aux, when aux starts at 0 and all its derivatives are absolute values of something, so clearly non-negative. And I'm not talking about some -1e-16 but a real -1e-4!) So this was indeed my first try, but I moved away from it, and it motivated my previous issue.

cholberg commented 1 month ago

@cholberg: indeed, removing this was deliberate (and despite our backward compatibility layer, the break here is the reason I bumped the minor version to 0.6.0, rather than continue down 0.5.x). The state object we were passing is an internal abstraction that isn't stable from release to release -- at times we may add or remove fields. It's also not clearly defined when doing a root find to find an event.

If there's any additional information we could reasonably be forwarding to the event, that we aren't at the moment, then I think that should become an additional **kwarg?

WDYT?

Sorry, totally missed this. I've been away on holiday for a bit. But yea, I think that makes a lot of sense. Adding an extra **kwargs should be an easy change.

patrick-kidger commented 1 month ago

As such the solver goes crazy

Do you mean with a fixed step size (the conditions of the RK method are not method and so it does numerically iffy things) or with an adaptive step size (the step size controller struggles due to the aux not being C^k)?

If the former then the appropriate thing to do would be to write your own solver (subclass AbstractSolver), wrapping together e.g. Tsit5 and Euler. Apply the former to y and the latter to aux.
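A minimal sketch of what such a combined step could look like, in plain Python and not as an actual AbstractSolver subclass. Classical RK4 stands in for Tsit5 here, and the toy dynamics y' = cos(t) with aux' = |y| are assumptions made for illustration only.

```python
import math

def f_y(t, y):
    # Toy smooth dynamics for the main variable: y' = cos(t), so y(t) = sin(t).
    return math.cos(t)

def f_aux(t, y):
    # aux' = |y| is continuous but not differentiable where y crosses zero.
    return abs(y)

def combined_step(t, y, aux, dt):
    # High-order update (classical RK4, standing in for Tsit5) for y only.
    k1 = f_y(t, y)
    k2 = f_y(t + dt / 2, y + dt / 2 * k1)
    k3 = f_y(t + dt / 2, y + dt / 2 * k2)
    k4 = f_y(t + dt, y + dt * k3)
    y_new = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    # Explicit Euler for aux: the increment dt * |y| can never be negative.
    aux_new = aux + dt * f_aux(t, y)
    return y_new, aux_new

def integrate(t0, t1, n_steps):
    # Fixed-step loop that records aux after every step.
    dt = (t1 - t0) / n_steps
    y, aux = 0.0, 0.0
    aux_history = [aux]
    for i in range(n_steps):
        y, aux = combined_step(t0 + i * dt, y, aux, dt)
        aux_history.append(aux)
    return y, aux, aux_history

y_end, aux_end, aux_history = integrate(0.0, 2 * math.pi, 2000)
```

Because every Euler increment is dt times an absolute value, aux cannot decrease in this scheme, whatever the smoothness of its derivative. The price is first-order accuracy for aux, while y keeps the higher-order update.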

If the latter then you can pass a PIDController(norm=...) that excludes aux from consideration when determining how to compute step sizes.
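A minimal sketch of the idea behind such a norm, in plain Python. It assumes the tolerance-scaled error arrives as a pair (y_err, aux_err) mirroring a state (y, aux); how PIDController actually passes errors to norm should be checked against the diffrax documentation. The function names and the numbers are hypothetical.

```python
import math

def rms_norm(values):
    # Root-mean-square norm over a flat sequence of numbers.
    return math.sqrt(sum(v * v for v in values) / len(values))

def full_norm(scaled_error):
    # Treats every component, including aux, as relevant for step acceptance.
    y_err, aux_err = scaled_error
    return rms_norm(list(y_err) + [aux_err])

def norm_excluding_aux(scaled_error):
    # Ignores the aux component, so only y drives the step size.
    y_err, _aux_err = scaled_error
    return rms_norm(list(y_err))

def accept_step(scaled_error, norm):
    # Common convention: accept when the norm of the scaled error is at most 1.
    return norm(scaled_error) <= 1.0

# Hypothetical scaled errors for one step: y is well resolved, aux is not.
scaled_error = ((0.2, 0.3), 5.0)
accepted_full = accept_step(scaled_error, full_norm)
accepted_without_aux = accept_step(scaled_error, norm_excluding_aux)
```

With the full norm the large aux error would force a rejection, while the restricted norm accepts the same step because y alone is within tolerance.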

etienney commented 1 month ago

It is a little bit of both, i.e. I'm using an adaptive step size, but the problem does not come from the step size controller struggling due to aux not being C^k, but rather from the fact that

the conditions of the RK method are not met (the word you intended, right? You wrote "method", so I'm unsure) and so it does numerically iffy things

But yeah, I can indeed modify my code to work with my own solver, and then update it to fit 0.6.0... I'll see if I ever find a simpler solution, but thanks for this one in any case!