patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Why are step_ts and jump_ts treated differently here? #483

Open andyElking opened 3 months ago

andyElking commented 3 months ago

Hi Patrick,

Am I correct in saying that the only differences between step_ts and jump_ts are the following:

But in addition to those discrepancies, it seems Diffrax treats them differently in one other way as well, which I am not sure I understand. Namely, the line below uses prev_dt=prev_dt if the step was clipped due to a jump, but prev_dt=t1-t0 if the step was clipped due to step_ts. I don't see why we should make a distinction between these two cases.

https://github.com/patrick-kidger/diffrax/blob/b977dce78382a4a5dc3e84e4bf25dfa2a8ae2bb2/diffrax/_step_size_controller/adaptive.py#L561

I would go even further and say that the line should just say prev_dt=t1-t0 in all cases. This is because the error of the current step depends on t1-t0, rather than on prev_dt, so I feel like keeping prev_dt in controller state is not needed. Here is what could go wrong with the current setup:

Say prev_dt=0.1, but due to jump_ts it was clipped to t1-t0 = 0.01. Also assume that the error was large and the step gets rejected and assume that the controller computes factor=0.5. Then the next step-size proposal will be 0.05, which is bigger than the step that was just taken, so it will again be clipped by jump_ts to 0.01, resulting in an infinite loop. Instead the new step proposal should just be (t1-t0)*factor = 0.005, which would presumably result in a smaller error and move forward.
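The arithmetic in this scenario can be sketched with plain Python (hypothetical numbers from the example above, not diffrax internals):

```python
# Hypothetical numbers from the scenario above: a step clipped by jump_ts
# is rejected, and we compare basing the next proposal on prev_dt vs t1 - t0.
prev_dt = 0.1          # step size the controller intended to take
t1_minus_t0 = 0.01     # step actually taken, after clipping at a jump
factor = 0.5           # error-based shrink factor after the rejection

# Proposal based on prev_dt: larger than the step just taken, so it would
# be clipped by jump_ts back down again.
proposal_from_prev = prev_dt * factor

# Proposal based on the step actually taken: genuinely smaller.
proposal_from_actual = t1_minus_t0 * factor

print(proposal_from_prev, proposal_from_actual)
```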

On the other hand if the step was clipped to a much smaller size than was intended (i.e. t1-t0 << prev_dt), then this will usually reflect in the error being small accordingly, resulting in a large factor. This means that (t1-t0)*factor would be again a reasonably large step-size proposal, whereas prev_dt*factor would be disproportionately massive.

Let me know if I missed something.

patrick-kidger commented 3 months ago

So I don't think this should ever be an infinite loop -- the next time around, prev_dt (pulled out of the controller state) will be 0.05, and we'll keep on shrinking the step.

I do take your point that if the previous step was rejected, we shouldn't use the prev_dt out of the controller state. In this case we should use the t1 - t0 that actually caused a step rejection. I'd be happy to have a tweak to this effect. (I think we need a test for this too! This is subtle.)

As for why we continue to use prev_dt, the reason is to handle the case in which prev_dt=<large>, but jump_ts clips it to <small>, and the step is accepted. In this case, we don't want the steps after the jump to have to slowly work their way back up to a good step size. This is particularly troublesome for problems which have many jumps. If it would only take 1 step to move between jumps given prev_dt, but we'd need to spend 3 steps working back up to the 'proper step size' given the clipped t1 - t0, then we'd end up tripling the number of steps required.
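The cost of working back up can be estimated with a back-of-the-envelope calculation (the numbers and the growth cap `factormax` here are illustrative assumptions, not diffrax defaults):

```python
import math

# Assumed illustrative numbers: a "good" step size, the small step forced
# by clipping at a jump, and a cap on per-step growth of the step size.
good_dt = 0.1
clipped_dt = 0.002
factormax = 4.0

# Number of accepted steps needed to grow from clipped_dt back to good_dt,
# if each accepted step can at most multiply the step size by factormax.
steps_to_recover = math.ceil(math.log(good_dt / clipped_dt) / math.log(factormax))
print(steps_to_recover)  # -> 3
```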

Whilst we're here I will note that there is one other difference between jump_ts and step_ts, and that is that the former causes the FSAL property of e.g. a Runge--Kutta method to discard the saved function evaluation, whilst in the latter case we can continue to use it.

andyElking commented 3 months ago

> So I don't think this should ever be an infinite loop -- as the next time around then prev_dt (pulled out of the controller state) should be 0.05, and we'll keep on shrinking the step.

Oh right, I missed that.

> I do take your point that if the previous step was rejected, we shouldn't use the prev_dt out of the controller state. In this case we should use the t1 - t0 that actually caused a step rejection. I'd be happy to have a tweak to this effect. (I think we need a test for this too! This is subtle.)

Sounds good, I'll include that change in the JumpStepControllerWrapper and we can discuss in more detail once I make the PR.

> As for why we continue to use prev_dt, the reason is to handle the case in which prev_dt=<large>, but jump_ts clips it to <small>, and the step is accepted. In this case, we don't want the steps after the jump to have to slowly work their way back up to a good step size. This is particularly troublesome for problems which have many jumps. If it would only take 1 step to move between jumps given prev_dt, but we'd need to spend 3 steps working back up to the 'proper step size' given the clipped t1 - t0, then we'd end up tripling the number of steps required.

I understand and I agree. However:

  1. If we choose to use prev_dt when the step was clipped due to jump_ts, I feel like we should also use it when it was clipped due to step_ts.
  2. Have you considered using something like max((t1-t0)*factor, prev_dt) in this case? After all, I feel like three steps that are too small but accepted are still better than three steps that are too big and rejected. And if the error was small, then factor should already be large anyway, no?

> Whilst we're here I will note that there is one other difference between jump_ts and step_ts, and that is that the former causes the FSAL property of e.g. a Runge--Kutta method to discard the saved function evaluation, whilst in the latter case we can continue to use it.

Thanks, that is very useful to know.

patrick-kidger commented 3 months ago
  1. That also sounds reasonable to me.
  2. Indeed, I think other heuristics could be deployed here too. I'm a bit wary about adding heuristics that only trigger after jumps though (which IIUC is what you're suggesting) -- that just sounds like it's getting a bit tricky to reason about / to debug.
andyElking commented 3 months ago

I would summarise the complete behaviour in 3 rules:

  1. We always have t1-t0 <= prev_dt (we can explicitly check this with an eqx.error_if), with inequality only when the step was clipped or when we hit the end of the integration interval (we do not explicitly check the latter, but I see no other way inequality could arise here).
  2. If the step was accepted, then next_dt must be >= prev_dt.
  3. If the step was rejected, then next_dt must be < t1-t0.

These can be implemented in a very simple way:

```python
dt_proposal = factor * (t1 - t0)  # note that if the step is rejected, then factor < 1
# Here comes the clipping between dt_min and dt_max

eqx.error_if(prev_dt, prev_dt < t1 - t0, "prev_dt must be >= t1-t0")

dt_proposal = jnp.where(keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal)
new_prev_dt = dt_proposal  # this goes into controller state as prev_dt

# Here comes the clipping due to step_ts and jump_ts and the whole nextafter(nextafter()) business
```

This has the nice property that it factors well into a controller (which does the first two lines) and a JumpStepWrapper which does all the rest. That means that made_jump (used only for the purposes of the nextafter business) and prev_dt are both kept in the state of the JumpStepWrapper and the inner controller never sees them.
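These rules can be exercised as a runnable stand-in (NumPy in place of jnp, skipping the error_if and the step_ts/jump_ts clipping; names are illustrative):

```python
import numpy as np

def propose_dt(t0, t1, prev_dt, factor, keep_step):
    # Proposals always scale the step actually taken; on a rejected step
    # factor < 1, so next_dt < t1 - t0 (rule 3).
    dt_proposal = factor * (t1 - t0)
    # An accepted step never proposes less than prev_dt (rule 2).
    return np.where(keep_step, np.maximum(dt_proposal, prev_dt), dt_proposal)

# Accepted step that was clipped (t1 - t0 = 0.01 << prev_dt = 0.1):
# the proposal snaps straight back up to prev_dt instead of crawling up.
print(propose_dt(0.0, 0.01, 0.1, 2.0, True))   # 0.1
# Rejected step: shrink relative to the step actually taken.
print(propose_dt(0.0, 0.01, 0.1, 0.5, False))  # 0.005
```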

WDYT?

patrick-kidger commented 2 months ago

On eqx.error_if: I try to use these really rarely, as they carry a performance overhead. In this case I don't think it's important enough to use one. (Besides that, note that you have to use its return value if you want the check to run.)

I think what you've got sounds reasonable. Mulling it over, I think it should be possible to do something even simpler: replace this line:

https://github.com/patrick-kidger/diffrax/blob/7384bfa1cd3222c2c5d8c705907e60bbf71587ec/diffrax/_step_size_controller/adaptive.py#L561

with

```diff
- prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)
+ prev_dt = jnp.where(made_jump & keep_step, prev_dt, t1 - t0)
```

(and if need be I think this still factors apart, as you describe).
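Under assumed names (made_jump and keep_step as booleans), the changed line behaves like this NumPy stand-in:

```python
import numpy as np

def carried_prev_dt(made_jump, keep_step, prev_dt, t0, t1):
    # Only carry the stored prev_dt forward when the step both hit a jump
    # and was accepted; otherwise fall back to the step actually taken.
    return np.where(made_jump & keep_step, prev_dt, t1 - t0)

print(carried_prev_dt(True, True, 0.1, 0.0, 0.01))   # 0.1: accepted jump step
print(carried_prev_dt(True, False, 0.1, 0.0, 0.01))  # 0.01: rejected, use t1 - t0
```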

andyElking commented 2 months ago

Thanks, that's good to know. I will keep the number of eqx.error_ifs to a minimum.

In practice it seems that your proposal doesn't lead to desirable behaviour. I compared our two approaches on a very simple example ODE and I was surprised how precisely the experiment echoed the issue I described in my first comment:

> On the other hand if the step was clipped to a much smaller size than was intended (i.e. t1-t0 << prev_dt), then this will usually reflect in the error being small accordingly, resulting in a large factor (and the step being accepted). This means that (t1-t0)*factor would be again a reasonable step-size proposal, whereas prev_dt*factor would be disproportionately massive.

In addition the experiment shows that my solution completely fixes this issue. You can find the experiment here. And here you can see why my proposal makes it easier to separate the jump_ts and step_ts into a wrapper.

patrick-kidger commented 2 months ago

Ah, I see what you mean.

Okay, in this case I think maybe what we should do is simply remove dt from the controller state altogether. Just always use t1 - t0, nothing else.

I believe your suggestion amounts to preventing the step size from shrinking after an accepted step. For context, some PID implementations exhibit this behaviour (e.g. torchdiffeq does this), but I recall deciding against it for Diffrax. It's a heuristic that I think helps some problems but hurts others.

andyElking commented 2 months ago

Fair enough, I'll get rid of it then. On the flip side, as I mentioned in #484, it seems it was possible for the step size to increase after a rejection (I discovered this in some unrelated experiment, where it just seemed to go on forever until max_steps was reached). Was this intentional, or is it right that I have now capped dt_proposal at self.safety*(t1-t0) when keep_step=False?

patrick-kidger commented 2 months ago

That is definitely pretty weird! I'm willing to believe that it happens, though.

In fact, here's something interesting I came across whilst looking at this just now:

https://github.com/patrick-kidger/diffrax/blob/0679807bf7bda0de3884a78cf20ce2dd4d8ade4c/diffrax/_step_size_controller/adaptive.py#L602

It seems we do prevent step shrinking after an accepted step after all! 😅

In light of this, maybe we should fix the case you just mentioned by also adding `factormax = jnp.where(keep_step, self.factormax, 1)`?

(EDIT: I've now seen that you mentioned this in #484. Ignore me, you're way ahead of me!)

andyElking commented 2 months ago

Well, technically this only prevents it from shrinking below t1-t0, which can still be smaller than prev_dt. But as you said before, all of these choices probably have pluses and minuses in different cases.

Haha yes, I was just about to point you to #484 😊