google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Added reversible Heun solver #97

Closed patrick-kidger closed 3 years ago

patrick-kidger commented 3 years ago

Code

Documentation

Examples

Tests

Note that there are still a few TODOs pending the arXiv reference for the new paper. I'm opening this as a PR so you can review it now if you get time, though there's no time pressure. I'll fill in the reference once the paper is on arXiv, which will be on Monday.

lxuechen commented 3 years ago

Thanks for the heads-up. I'll work on this on Sunday.

patrick-kidger commented 3 years ago

@lxuechen do you have an ETA for reviewing this? (+Once this PR is in I'd like to announce the paper on Twitter.)

lxuechen commented 3 years ago

@lxuechen do you have an ETA for reviewing this? (+Once this PR is in I'd like to announce the paper on Twitter.)

I will send in comments on code by Friday. I'll send in another email about the paper. Sorry for the delay.

patrick-kidger commented 3 years ago

No worries, thanks.

patrick-kidger commented 3 years ago

Thanks for the review. I'm glad things seem good overall.

Agreed on all comments I haven't otherwise responded to. I'll make changes shortly.

Major answers:

  1. I don't think any of the state in the current solvers is ever mutated. (IMO a good thing.) I was looking to preserve this property.

Generally I'm a fan of explicitly managing hidden state in loops, in something of a functional style, rather than saving it in a mutable reference. I find it much easier to reason about.

  2. I think (only) the SRK and log-ODE methods care about making a distinction between batch and channel dimensions, since they depend on some Lévy-area-like quantity.

However, both of these are (coincidentally?) already verboten for adjoint SDEs, as they need direct access to diffusions rather than diffusion-vector products.

In short, I think we lucked out, but IMO this speaks to a need to redesign things somehow. I've not thought through the exact details of how.

  3. Two main things to think about.

One, we'd need to record every step location on the forward pass. (O(T) memory in theory, but in practice this should be fine.) Then on the adjoint pass we'd need to step to those same locations, in reverse order, to properly reconstruct things.

Two, I believe James' current theory is that reversible Heun will (only?) converge with adaptive step sizes if pairs of adjacent steps are the same size. So we'd need to alternate adaptive stepping on/off for each step.

Both should be totally doable. Admittedly not something I'm hoping to get done any time soon, though; writing my thesis is currently the priority.
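As a rough illustration of points 1 and 3 together, here is a toy sketch (not torchsde's actual API; `step`, `solve_forward`, and `solve_adjoint` are hypothetical names, and a deterministic Euler step stands in for an SDE solver step): the forward pass threads its state explicitly through the loop, recording every accepted step endpoint, and the adjoint pass replays that grid in reverse.

```python
def solve_forward(step, y0, t0, t1, dt):
    """Integrate forward, threading state through the loop explicitly
    (functional style, no mutation) and recording every accepted step
    endpoint for later replay on the adjoint pass."""
    ts = [t0]
    t, y = t0, y0
    while t < t1:
        t_next = min(t + dt, t1)
        y = step(t, t_next, y)
        ts.append(t_next)
        t = t_next
    return y, ts


def solve_adjoint(adjoint_step, adj_state, ts):
    """Step through the recorded grid in reverse order, so the backward
    pass visits exactly the same step locations as the forward pass."""
    rev = ts[::-1]
    for t_right, t_left in zip(rev, rev[1:]):
        adj_state = adjoint_step(t_right, t_left, adj_state)
    return adj_state
```

The O(T) memory cost mentioned above is just the `ts` list; the state `y` itself is never stored, only re-evolved.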

lxuechen commented 3 years ago

However both of these are (coincidentally?) already verboten for adjoint SDEs as they need direct access to diffusions rather than diffusion-vector products. In short I think we lucked out, but IMO this speaks to a need to redesign things somehow. I've not thought through the exact details of how.

I think this is the case. Though I'd still prefer we avoid the extra squeezing and unsqueezing unless there's an explicit solver for the adjoint that involves this issue.

I'll take another look tomorrow just to double check.

lxuechen commented 3 years ago

Looks good! Regarding the functional-style vs OO-style thing, I think we'll encounter this discussion again when dealing with adaptive solvers, since the steps need to be recorded somehow, and it's probably easiest to do this via an internal state of the solver.
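The internal-state alternative could look something like the following hedged sketch (not torchsde's API; `RecordingSolver` and its attributes are hypothetical), where the solver object records the accepted step endpoints itself rather than having the loop thread them through:

```python
class RecordingSolver:
    """OO-style sketch: the solver keeps the accepted step endpoints as
    internal state (self.ts), mutated as integration proceeds."""

    def __init__(self, step_fn, dt):
        self.step_fn = step_fn
        self.dt = dt
        self.ts = []  # internal record of accepted step endpoints

    def integrate(self, y0, t0, t1):
        t, y = t0, y0
        self.ts = [t0]
        while t < t1:
            t_next = min(t + self.dt, t1)
            y = self.step_fn(t, t_next, y)
            self.ts.append(t_next)  # record via mutation of solver state
            t = t_next
        return y
```

The trade-off is the one discussed above: the record is conveniently co-located with the solver, but the loop's behaviour now depends on mutable state rather than on values passed explicitly.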

I also plan to write some minor but helpful docstrings for the code later next week.

patrick-kidger commented 3 years ago

Thank you, I'm very happy to get this in!