andyElking opened 3 months ago
I have now also added the functionality to revisit rejected steps. I also improved the runtime of the step_ts and jump_ts handling: the controller no longer searches the whole array each time, but instead keeps an index of where in the array it was previously.
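Here is a minimal sketch of the "keep an index instead of searching the whole array" idea, written in plain Python for readability. The helper name next_step_t and its signature are hypothetical and not the diffrax API. It assumes that t0 (the start of the current step) never decreases, which holds because a rejected step is retried from the same t0. In the actual JAX code the index would presumably live in the controller state, and the scan would need something like jax.lax.while_loop.

```python
def next_step_t(step_ts, t0, prev_idx):
    """Return (first entry of step_ts strictly greater than t0, its index).

    Hypothetical helper. Instead of searching all of step_ts on every call,
    the scan resumes from prev_idx. This is valid because t0 never decreases.
    Returns (None, len(step_ts)) once step_ts is exhausted.
    """
    idx = prev_idx
    # Skip every entry <= t0. Over a whole solve this loop does
    # O(len(step_ts)) work in total, rather than O(len(step_ts)) per step.
    while idx < len(step_ts) and step_ts[idx] <= t0:
        idx += 1
    if idx == len(step_ts):
        return None, idx
    return step_ts[idx], idx


# Usage: carry idx from one step to the next.
step_ts = [0.5, 1.0, 1.5, 2.0]
idx = 0
for t0 in [0.0, 0.3, 0.7, 1.2]:
    t_next, idx = next_step_t(step_ts, t0, idx)
    print(t0, t_next, idx)
```

The trade-off is a small amount of extra state (one integer) in exchange for not rescanning step_ts on every step. This matters most when step_ts is long.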
I also think there was a bug in the PID controller: it would sometimes reject a step but still have factor > 1, i.e. propose a larger next step. To remedy this, I modified the following:
https://github.com/patrick-kidger/diffrax/blob/501bed52fb3726aaee41798aaab46d2bdff1a680/diffrax/_step_size_controller/adaptive.py#L569-L574
Possibly something smaller than just self.safety would make even more sense. I feel that if a step is rejected, the next step should be at most half as large. But I'm not an expert.
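Here is a minimal sketch of the invariant this change is meant to enforce: a rejected step should never propose a larger next step. It does not reproduce the linked lines. The function name and the reject_cap parameter are hypothetical, and safety=0.9 is only an assumed default for illustration. The optional reject_cap shows the stricter "at most half as large" variant (e.g. 0.5). In JAX this would be written as jnp.where(keep_step, factor, jnp.minimum(factor, cap)) rather than a Python if.

```python
def clamp_factor(factor, keep_step, safety=0.9, reject_cap=None):
    """Hypothetical sketch: clamp a PID step-size factor after a rejection.

    factor: proposed multiplier for the next step size (dt_next = factor * dt).
    keep_step: whether the current step was accepted.
    safety: safety factor (0.9 is an assumed value for illustration).
    reject_cap: optional stricter cap applied on rejection, e.g. 0.5.
    """
    if keep_step:
        # Accepted steps are left untouched.
        return factor
    cap = safety if reject_cap is None else min(safety, reject_cap)
    # A rejected step must never lead to a larger (or equal-size) next step.
    return min(factor, cap)


print(clamp_factor(1.3, keep_step=False))                  # capped at safety
print(clamp_factor(1.3, keep_step=False, reject_cap=0.5))  # stricter cap
```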
I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of eqx.error_if statements to check that the necessary invariants are always maintained. This is a bit experimental, though, so there may be bugs I didn't test for.
I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere.
P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper.
Hi @patrick-kidger,
I got rid of some of the eqx.error_if checks that I had added to my JumpStepWrapper and redid the timing benchmarks. My new implementation was already faster than the old PIDController, but now the speedup is much more significant, especially when step_ts is long (think >100 entries). Surprisingly, it is faster even when it has to revisit rejected steps. See https://github.com/patrick-kidger/diffrax/blob/345e23ad34abe095d25b8be8d621360afd592165/benchmarks/jump_step_timing.py#L126-L128
Thanks for the review! I made all the edits I could and left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of prev_dt entirely, as you suggested in #483?
> Also, should I get rid of prev_dt entirely, as you suggested in #483?
If it's easy to do that in a separate commit afterwards, then I would say yes. A separate commit, just so it's easy to revert if it turns out we were wrong about something here :D
Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but progress is slower now due to my internship.
Hi Patrick,
I factored jump_ts and step_ts out of the PIDController into JumpStepWrapper (I'm not very set on this name, let me know if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:
1. t1-t0 <= prev_dt (this is checked via eqx.error_if), with strict inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).
2. If the step was accepted, next_dt must be >= prev_dt.
3. If the step was rejected, next_dt must be < t1-t0.
We achieve this in a very simple way here: https://github.com/patrick-kidger/diffrax/blob/78b122adf39b2f8d26a79d0ac239a2fb675653a1/diffrax/_step_size_controller/jump_step_wrapper.py#L119-L123
The next step is to add a parameter JumpStepWrapper.revisit_rejected_steps, which does what you'd expect. That will appear in a future commit in this same PR.
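Here is a minimal sketch of one way revisiting rejected steps could work. It follows an assumed reading of the feature: when a step [t0, t1] is rejected, t1 is remembered, and later steps are clipped so that the solver lands exactly on it before moving past it. RevisitBuffer and its methods are hypothetical names. For readability it uses a Python list as a stack. A JAX implementation would need a fixed-length array plus an index, so the buffer length would have to be chosen up front.

```python
class RevisitBuffer:
    """Hypothetical sketch: remember end times of rejected steps and force
    later steps to land exactly on them before moving past them."""

    def __init__(self):
        # Pending times, stored in decreasing order so the earliest is on top.
        self.stack = []

    def clip(self, t1_proposed):
        # Never step past the earliest pending time.
        if self.stack and t1_proposed > self.stack[-1]:
            return self.stack[-1]
        return t1_proposed

    def on_reject(self, t1):
        # clip() guarantees t1 <= top, so pushing only when t1 < top keeps
        # the stack sorted and avoids storing the same time twice.
        if not self.stack or t1 < self.stack[-1]:
            self.stack.append(t1)

    def on_accept(self, t1):
        # clip() returned the stored value itself, so exact equality is safe.
        if self.stack and t1 == self.stack[-1]:
            self.stack.pop()


buf = RevisitBuffer()
buf.on_reject(buf.clip(1.0))  # step [0, 1.0] rejected; 1.0 becomes pending
buf.on_accept(buf.clip(0.5))  # smaller step [0, 0.5] accepted
print(buf.clip(1.5))          # next step is clipped to the pending 1.0
```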