patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Added JumpStepWrapper #484

Open andyElking opened 3 months ago

andyElking commented 3 months ago

Hi Patrick,

I factored the jump_ts and step_ts out of the PIDController into JumpStepWrapper (I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:

  1. We always have t1-t0 <= prev_dt (this is checked via eqx.error_if), with strict inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).
  2. If the step was accepted, then next_dt must be >=prev_dt.
  3. If the step was rejected, then next_dt must be < t1-t0.

We achieve this in a very simple way here: https://github.com/patrick-kidger/diffrax/blob/78b122adf39b2f8d26a79d0ac239a2fb675653a1/diffrax/_step_size_controller/jump_step_wrapper.py#L119-L123
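For readers skimming the thread, the three rules can be written down as a small check. This is only a sketch in plain Python, not the linked code: `check_step_invariants` is a hypothetical helper, and the real implementation works on traced JAX values and uses eqx.error_if rather than returning a bool.

```python
def check_step_invariants(t0, t1, prev_dt, next_dt, accepted):
    """Return True if one step satisfies the three rules (hypothetical helper).

    t0, t1   -- start and end of the step that was just attempted
    prev_dt  -- step size the controller proposed before clipping
    next_dt  -- step size the controller proposes for the next attempt
    accepted -- whether the attempted step was accepted
    """
    dt = t1 - t0
    # Rule 1: the step actually taken never exceeds the proposed size.
    # It is strictly smaller only if it was clipped (e.g. at a step_ts entry)
    # or if it hit the end of the integration interval.
    rule1 = dt <= prev_dt
    if accepted:
        # Rule 2: after an accepted step the proposal must not shrink.
        rule23 = next_dt >= prev_dt
    else:
        # Rule 3: after a rejected step the proposal must shrink strictly.
        rule23 = next_dt < dt
    return rule1 and rule23


# A step clipped from 0.5 to 0.25, then accepted: the next proposal is still 0.5.
ok = check_step_invariants(t0=0.0, t1=0.25, prev_dt=0.5, next_dt=0.5, accepted=True)
```

Note that rule 2 compares against prev_dt rather than t1-t0, so a clipped step does not permanently shrink the step size.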

The next step is to add a parameter JumpStepWrapper.revisit_rejected_steps which does what you expect. That will appear in a future commit in this same PR.

andyElking commented 3 months ago

I have now also added the functionality to revisit rejected steps. I also improved the runtime of step_ts and jump_ts: the controller no longer searches the whole array each time, but keeps an index of where in the array it was previously.
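To illustrate the indexing idea, here is a rough plain-Python sketch, not the PR's code. Both function names are hypothetical, and the real controller keeps the index in its JAX state instead of a Python int. It assumes step_ts is sorted in increasing order and that t0 never moves backwards (a rejected step is retried from the same t0).

```python
def next_stop_by_search(step_ts, t0):
    """Scan the whole sorted array for the first entry strictly after t0."""
    for s in step_ts:
        if s > t0:
            return s
    return float("inf")  # no further forced stops


def next_stop_by_index(step_ts, t0, idx):
    """Same result, but resume from a remembered position idx.

    Returns (next_stop, new_idx). Since t0 only increases over the solve,
    idx only moves forward, so the total work is linear in len(step_ts)
    rather than linear per step.
    """
    while idx < len(step_ts) and step_ts[idx] <= t0:
        idx += 1
    stop = step_ts[idx] if idx < len(step_ts) else float("inf")
    return stop, idx


step_ts = [1.0, 2.0, 3.0]
idx = 0
stops = []
for t0 in [0.0, 0.5, 1.0, 1.0, 2.5, 3.0]:  # repeated 1.0 mimics a rejected step
    stop, idx = next_stop_by_index(step_ts, t0, idx)
    stops.append(stop)
```

The repeated t0 in the usage loop shows that a retry after rejection sees the same next stop without moving the index.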

Also, I think there was a bug in the PID controller, where it would sometimes reject a step but still have factor>1. To remedy this I modified the following: https://github.com/patrick-kidger/diffrax/blob/501bed52fb3726aaee41798aaab46d2bdff1a680/diffrax/_step_size_controller/adaptive.py#L569-L574 I think something smaller than just self.safety might make even more sense; I feel that if a step is rejected, the next step should be at most half the size. But I'm not an expert.
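The kind of fix meant here can be sketched as follows. This is an assumption about the shape of the change, not a copy of the linked lines: `clip_factor` is a hypothetical helper, the parameter names mirror PIDController's safety, factormin and factormax arguments, the values are illustrative, and the real code would use jnp.where and jnp.clip on traced values.

```python
def clip_factor(raw_factor, accepted, safety=0.9, factormin=0.2, factormax=10.0):
    """Scale and clip the step-size factor (hypothetical helper).

    raw_factor -- factor suggested by the error estimate before safety scaling
    accepted   -- whether the step that produced this estimate was accepted
    """
    scaled = safety * raw_factor
    # Assumed fix: on rejection, cap the factor at `safety` (< 1) so the next
    # step is guaranteed to be smaller than the rejected one.
    upper = factormax if accepted else safety
    return min(max(scaled, factormin), upper)


rejected_factor = clip_factor(2.0, accepted=False)  # capped at safety
accepted_factor = clip_factor(2.0, accepted=True)   # free to grow
```

Replacing the cap `safety` with something like 0.5 would implement the stricter variant suggested above.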

I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of eqx.error_if statements to make sure the necessary invariants are always maintained. But this is a bit experimental, so maybe there are some bugs I didn't test for.

I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere.

P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper.

andyElking commented 3 months ago

Hi @patrick-kidger, I got rid of some eqx.error_ifs that I added to my JumpStepWrapper and redid the timing benchmarks. My new implementation was already faster than the old PIDController before, but now this is way more significant, especially when step_ts is long (think >100). Surprisingly, it is faster even when it has to revisit rejected steps. See https://github.com/patrick-kidger/diffrax/blob/345e23ad34abe095d25b8be8d621360afd592165/benchmarks/jump_step_timing.py#L126-L128

andyElking commented 3 months ago

Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of prev_dt entirely, as you suggested in #483?

patrick-kidger commented 1 week ago

> Also, should I get rid of prev_dt entirely, as you suggested in #483?

If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D

andyElking commented 1 week ago

Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship.