andyElking opened 3 months ago
So I don't think this should ever be an infinite loop -- as the next time around then `prev_dt` (pulled out of the controller state) should be `0.05`, and we'll keep on shrinking the step.

I do take your point that if the previous step was rejected, we shouldn't use the `prev_dt` out of the controller state. In this case we should use the `t1 - t0` that actually caused a step rejection. I'd be happy to have a tweak to this effect. (I think we need a test for this too! This is subtle.)
As for why we continue to use `prev_dt`, the reason is to handle the case in which `prev_dt=<large>`, but `jump_ts` clips it to `<small>`, and the step is accepted. In this case, we don't want the steps after the jump to have to slowly work their way back up to a good step size. This is particularly troublesome for problems which have many jumps. If it would only take 1 step to move between jumps given `prev_dt`, but we'd need to spend 3 steps working back up to the 'proper step size' given the clipped `t1 - t0`, then we'd end up tripling the number of steps required.
Whilst we're here I will note that there is one other difference between `jump_ts` and `step_ts`, and that is that the former causes the FSAL property of e.g. a Runge--Kutta method to discard the saved function evaluation, whilst in the latter case we can continue to use it.
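As a plain-Python sketch of that bookkeeping (the function and flag names here are assumptions for illustration, not Diffrax's internals):

```python
def can_reuse_fsal(keep_step, made_jump):
    # FSAL ("first-same-as-last"): the final stage evaluation of an accepted
    # step can serve as the first stage of the next one -- unless the vector
    # field may be discontinuous at t1, i.e. the step ended on a jump.
    return keep_step and not made_jump

print(can_reuse_fsal(True, False))  # True: a step_ts clip keeps the evaluation
print(can_reuse_fsal(True, True))   # False: a jump_ts clip discards it
```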
> So I don't think this should ever be an infinite loop -- as the next time around then `prev_dt` (pulled out of the controller state) should be `0.05`, and we'll keep on shrinking the step.
Oh right, I missed that.
> I do take your point that if the previous step was rejected, we shouldn't use the `prev_dt` out of the controller state. In this case we should use the `t1 - t0` that actually caused a step rejection. I'd be happy to have a tweak to this effect. (I think we need a test for this too! This is subtle.)
Sounds good, I'll include that change in the `JumpStepControllerWrapper` and we can discuss in more detail once I make the PR.
> As for why we continue to use `prev_dt`, the reason is to handle the case in which `prev_dt=<large>`, but `jump_ts` clips it to `<small>`, and the step is accepted. In this case, we don't want the steps after the jump to have to slowly work their way back up to a good step size. This is particularly troublesome for problems which have many jumps. If it would only take 1 step to move between jumps given `prev_dt`, but we'd need to spend 3 steps working back up to the 'proper step size' given the clipped `t1 - t0`, then we'd end up tripling the number of steps required.
I understand and I agree. However:

- If we use `prev_dt` when the step was clipped due to `jump_ts`, I feel like we should also use it when it was clipped due to `step_ts`.
- Why not use `max((t1-t0)*factor, prev_dt)` in this case? After all, I feel like three steps that are too small but accepted are still better than three steps that are too big and rejected. And if the error was small, then `factor` should already be large anyway, no?

> Whilst we're here I will note that there is one other difference between `jump_ts` and `step_ts`, and that is that the former causes the FSAL property of e.g. a Runge--Kutta method to discard the saved function evaluation, whilst in the latter case we can continue to use it.
Thanks, that is very useful to know.
I would summarise the complete behaviour in 3 rules:

1. `t1-t0 <= prev_dt` (we can explicitly check that with an `eqx.error_if`), with inequality only when the step was clipped or if we hit the end of the integration interval (we do not explicitly check that, but I see no other way inequality could arise here).
2. If the step was accepted, `next_dt` must be `>= prev_dt`.
3. If the step was rejected, `next_dt` must be `< t1-t0`.

These can be implemented in a very simple way:
```python
dt_proposal = factor*(t1 - t0)  # note that if the step is rejected, then factor < 1
# Here comes the clipping between dt_min and dt_max
eqx.error_if(prev_dt, prev_dt < t1-t0, "prev_dt must be >= t1-t0")
dt_proposal = jnp.where(keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal)
new_prev_dt = dt_proposal  # this goes into controller state as prev_dt
# Here comes the clipping due to step_ts and jump_ts and the whole nextafter(nextafter()) business
```
This has the nice property that it factors well into a controller (which does the first two lines) and a `JumpStepWrapper` which does all the rest. That means that `made_jump` (used only for the purposes of the `nextafter` business) and `prev_dt` are both kept in the state of the `JumpStepWrapper`, and the inner controller never sees them.

WDYT?
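For concreteness, here is a minimal plain-Python sketch of how that factoring could look. All names and signatures here are assumptions for illustration only, not Diffrax's actual API, and the inner "controller" is a toy stand-in:

```python
from typing import Any, NamedTuple

class JumpStepState(NamedTuple):
    made_jump: bool   # did the last step end exactly on a jump?
    prev_dt: float    # the step size we intended before any clipping
    inner: Any        # opaque state of the wrapped inner controller

def wrapper_adapt(t0, t1, keep_step, state, inner_adapt, next_jump):
    # 1) The inner controller proposes a dt from the step actually taken.
    dt_proposal, inner = inner_adapt(t0, t1, keep_step, state.inner)
    # 2) After an accepted step, never propose less than the intended prev_dt.
    if keep_step:
        dt_proposal = max(dt_proposal, state.prev_dt)
    new_prev_dt = dt_proposal
    # 3) Clip to the next jump, remembering whether we hit it.
    next_t0 = t1 if keep_step else t0
    made_jump = next_t0 + dt_proposal >= next_jump
    next_t1 = min(next_t0 + dt_proposal, next_jump)
    return next_t0, next_t1, JumpStepState(made_jump, new_prev_dt, inner)

def toy_inner(t0, t1, keep_step, inner):
    # Stand-in for a PID controller: grow accepted steps, shrink rejected ones.
    factor = 2.0 if keep_step else 0.5
    return factor * (t1 - t0), inner

# A step clipped by a jump to t1 - t0 = 0.01 was accepted; thanks to the
# stored prev_dt = 0.1 the next proposal jumps straight back up.
state = JumpStepState(made_jump=True, prev_dt=0.1, inner=None)
t0, t1, state = wrapper_adapt(0.0, 0.01, True, state, toy_inner, next_jump=0.5)
print(t1 - t0)  # ~0.1, not the slow 0.02 the toy inner controller proposed
```

The point of the sketch is the separation of concerns: `made_jump` and `prev_dt` live only in the wrapper's state, and `inner_adapt` sees neither.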
On `eqx.error_if`: I try to use these really rarely, as they carry a performance overhead. In this case I don't think it's important enough to use it. (Besides that, note that you have to use its return value if you want the check to run.)
I think what you've got sounds reasonable. Mulling it over, I think it should be possible to do something even simpler: change this line:

```diff
- prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)
+ prev_dt = jnp.where(made_jump & keep_step, prev_dt, t1 - t0)
```

(and if need be I think this still factors apart, as you describe).
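A plain-Python restatement of that one-line change (illustrative only): the intended `prev_dt` survives into the controller state only when the step both hit a jump and was accepted.

```python
def stored_prev_dt(made_jump, keep_step, prev_dt, t0, t1):
    # Mirrors jnp.where(made_jump & keep_step, prev_dt, t1 - t0).
    return prev_dt if (made_jump and keep_step) else (t1 - t0)

print(stored_prev_dt(True, True, 0.1, 0.0, 0.01))   # 0.1: jump-clipped and accepted
print(stored_prev_dt(True, False, 0.1, 0.0, 0.01))  # 0.01: rejected, use the actual step
```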
Thanks, that's good to know. I will keep the number of `eqx.error_if`s to a minimum.
In practice it seems that your proposal doesn't lead to desirable behaviour. I compared our two approaches on a very simple example ODE and I was surprised how precisely the experiment echoed the issue I described in my first comment:
> On the other hand if the step was clipped to a much smaller size than was intended (i.e. `t1-t0 << prev_dt`), then this will usually reflect in the error being small accordingly, resulting in a large `factor` (and the step being accepted). This means that `(t1-t0)*factor` would be again a reasonable step-size proposal, whereas `prev_dt*factor` would be disproportionately massive.
In addition, the experiment shows that my solution completely fixes this issue. You can find the experiment here. And here you can see why my proposal makes it easier to separate the `jump_ts` and `step_ts` handling into a wrapper.
Ah, I see what you mean.
Okay, in this case I think maybe what we should do is simply remove `dt` from the controller state altogether. Just always use `t1 - t0`, nothing else.
I believe your suggestion amounts to preventing the step size from shrinking after an accepted step. For context, some PID implementations exhibit this behaviour (e.g. torchdiffeq does this), but I recall deciding against this for Diffrax. It's a heuristic that I think helps some problems but hurts others.
Fair enough, I'll get rid of it then. On the flip side, as I mentioned in #484, it seems like it was possible for the step size to increase after rejecting a step (I discovered this in some unrelated experiment, where it just seemed to go on forever until `max_steps` was reached). Was this intentional, or is it good that I now capped `dt_proposal` at `self.safety*(t1-t0)` when `keep_step=False`?
That is definitely pretty weird! I'm willing to believe that it happens, though.
In fact, here's something interesting I came across whilst looking at this just now:
It seems we do prevent step shrinking after an accepted step after all! 😅
In light of this, maybe we should fix the case you just mentioned by also adding `factormax = jnp.where(keep_step, self.factormax, 1)`?
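Sketched in plain Python (with assumed parameter names, mirroring the `jnp.where` one-liner just proposed), the effect is that a rejected step can never propose a larger dt:

```python
def clip_factor(factor, keep_step, factormin=0.2, factormax=10.0):
    # After a rejected step, cap growth at 1 so dt can never increase.
    upper = factormax if keep_step else 1.0
    return min(max(factor, factormin), upper)

print(clip_factor(3.0, keep_step=True))   # 3.0: accepted steps may grow
print(clip_factor(3.0, keep_step=False))  # 1.0: rejected steps may not
```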
(EDIT: I've now seen that you mentioned this in #484. Ignore me, you're way ahead of me!)
Well, technically this prevents it only from shrinking below `t1-t0`, which can still be smaller than `prev_dt`. But as you said before, probably all of these choices can bring pluses and minuses in different cases.
Haha yes, I was just about to point you to #484 😊
Hi Patrick,
Am I correct in saying that the only differences between `step_ts` and `jump_ts` are the following:

- `jump_ts` cannot be integers, but must be floats (so that you can do `prevbefore` and `nextafter`).
- `_clip_jump_ts` also returns `made_jump`, which is used to determine whether we need to do `_t1 = nextafter(nextafter(t1))`.

But in addition to those discrepancies, it seems Diffrax treats them differently in one other way as well, which I am not sure I understand. Namely, the line below uses `prev_dt=prev_dt` if the step was clipped due to a jump, but `prev_dt=t1-t0` if the step was clipped due to `step_ts`. I don't see why we should make a distinction between these two cases.

https://github.com/patrick-kidger/diffrax/blob/b977dce78382a4a5dc3e84e4bf25dfa2a8ae2bb2/diffrax/_step_size_controller/adaptive.py#L561
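For reference, the float requirement comes from needing the neighbouring representable values around each jump; in plain Python this is `math.nextafter` (a sketch of the idea, not Diffrax's `prevbefore` implementation):

```python
import math

t_jump = 1.5
before = math.nextafter(t_jump, -math.inf)  # largest float strictly below t_jump
after = math.nextafter(t_jump, math.inf)    # smallest float strictly above t_jump

# Stepping to `before` ends the step just left of the discontinuity; the next
# step then starts from `after`, just right of it.
print(before < t_jump < after)  # True
```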
I would go even further and say that the line should just say `prev_dt=t1-t0` in all cases. This is because the error of the current step depends on `t1-t0`, rather than on `prev_dt`, so I feel like keeping `prev_dt` in controller state is not needed. Here is what could go wrong with the current setup:

Say `prev_dt=0.1`, but due to `jump_ts` it was clipped to `t1-t0 = 0.01`. Also assume that the error was large and the step gets rejected, and assume that the controller computes `factor=0.5`. Then the next step-size proposal will be `0.05`, which is bigger than the step that was just taken, so it will again be clipped by `jump_ts` to `0.01`, resulting in an infinite loop. Instead the new step proposal should just be `(t1-t0)*factor = 0.005`, which would presumably result in a smaller error and move forward.
On the other hand, if the step was clipped to a much smaller size than was intended (i.e. `t1-t0 << prev_dt`), then this will usually reflect in the error being small accordingly, resulting in a large `factor`. This means that `(t1-t0)*factor` would again be a reasonably large step-size proposal, whereas `prev_dt*factor` would be disproportionately massive.

Let me know if I missed something.
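The infinite-loop scenario above can be checked with a few lines of plain Python (toy numbers only; not a simulation of the actual controller):

```python
# Sketch of the rejection loop described above.
prev_dt, clip_to, factor = 0.1, 0.01, 0.5

# Proposal based on prev_dt: 0.1 * 0.5 = 0.05, which jump_ts clips straight
# back to 0.01 -- the same step that was just rejected, so we loop forever.
from_prev_dt = min(prev_dt * factor, clip_to)

# Proposal based on the step actually taken: (t1 - t0) * factor = 0.005,
# which is genuinely smaller and lets the solver make progress.
t1_minus_t0 = clip_to
from_actual = t1_minus_t0 * factor

print(from_prev_dt)  # 0.01
print(from_actual)   # 0.005
```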