Closed patrick-kidger closed 3 years ago
Thanks for the heads-up. I'll work on this on Sunday.
@lxuechen do you have an ETA for reviewing this? (+Once this PR is in I'd like to announce the paper on Twitter.)
I will send in comments on code by Friday. I'll send in another email about the paper. Sorry for the delay.
No worries, thanks.
Thanks for the review. I'm glad things seem good overall.
Agreed on all comments I haven't otherwise responded to. I'll make changes shortly.
Major answers:
Generally I'm a fan of explicitly managing hidden state in loops, in something of a functional style, instead of saving them in a mutating reference. I find it much easier to reason about.
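As a minimal sketch of what this functional style can look like (not the PR's code: `heun_step` and `integrate` are hypothetical names, and a plain drift-only Heun step stands in for the real solvers), the extra state is passed into each step and returned from it, rather than stored on the solver object:

```python
def heun_step(f, t, dt, y, f_prev):
    """One Heun step for dy/dt = f(t, y).

    The extra state `f_prev` (the drift evaluated at the current point) is
    passed in and handed back explicitly instead of being saved on `self`.
    """
    y_pred = y + dt * f_prev                    # Euler predictor
    f_pred = f(t + dt, y_pred)
    y_next = y + 0.5 * dt * (f_prev + f_pred)   # trapezoidal corrector
    f_next = f(t + dt, y_next)                  # extra state for the next step
    return y_next, f_next


def integrate(f, y0, ts):
    """Drive the solver; the loop owns the state, nothing is mutated in place."""
    y = y0
    extra = f(ts[0], y0)  # extra state is initialised outside the step function
    for t0, t1 in zip(ts[:-1], ts[1:]):
        y, extra = heun_step(f, t0, t1 - t0, y, extra)
    return y
```

The mutating alternative would keep `f_prev` as an attribute updated inside the step; the version above makes every dependency of a step visible in its signature.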
However both of these are (coincidentally?) already verboten for adjoint SDEs as they need direct access to diffusions rather than diffusion-vector products.
In short I think we lucked out, but IMO this speaks to a need to redesign things somehow. I've not thought through the exact details of how.
One, we'd need to record every step location for the forward pass. (O(T) memory in theory but in practice should probably be fine.) Then on the adjoint pass we'd need to step to those same locations, in reverse order, to properly reconstruct things.
Two, I believe James' current theory is that reversible Heun will (only?) converge with adaptive step sizes if pairs of adjacent steps are the same size. So we'd need to alternate the adaptive stepping on/off for each step.
Both should be totally doable. Admittedly not something I'm hoping to get done any time soon though; currently writing the thesis is my priority.
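A rough sketch of those two points together, under stated assumptions: the diffusion is dropped so this is a drift-only toy, the update is written in the spirit of reversible Heun rather than copied from the PR, and `propose_dt` is a made-up step-size controller. The forward pass adapts the step size once per pair of steps and records every step location; the backward pass walks those locations in reverse to reconstruct the state.

```python
def drift(z):
    # Toy drift for dy/dt = -y (assumption; stands in for the real SDE terms).
    return -z


def forward_step(y, z, dt):
    # Reversible, Heun-like update carrying an auxiliary state z alongside y.
    z_next = 2 * y - z + drift(z) * dt
    y_next = y + 0.5 * (drift(z) + drift(z_next)) * dt
    return y_next, z_next


def backward_step(y_next, z_next, dt):
    # Algebraic inverse of forward_step for the same dt.
    z = 2 * y_next - z_next - drift(z_next) * dt
    y = y_next - 0.5 * (drift(z_next) + drift(z)) * dt
    return y, z


def propose_dt(y):
    # Hypothetical controller: take smaller steps where |y| is larger.
    return 0.1 / (1.0 + abs(y))


def solve_forward(y0, t0, num_pairs):
    t, y, z = t0, y0, y0
    ts = [t0]  # record of every step location, O(T) memory
    for _ in range(num_pairs):
        dt = propose_dt(y)      # adaptive stepping "on": choose a new size
        for _ in range(2):      # adaptive stepping "off": the pair shares dt
            y, z = forward_step(y, z, dt)
            t = t + dt
            ts.append(t)
    return y, z, ts


def solve_backward(y, z, ts):
    # Revisit the recorded locations in reverse order.
    for t_prev, t_next in zip(reversed(ts[:-1]), reversed(ts[1:])):
        y, z = backward_step(y, z, t_next - t_prev)
    return y, z
```

Calling `solve_forward(1.0, 0.0, 5)` and feeding its outputs to `solve_backward` should recover the initial state up to floating-point error, which is the property the adjoint pass would rely on.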
However both of these are (coincidentally?) already verboten for adjoint SDEs as they need direct access to diffusions rather than diffusion-vector products. In short I think we lucked out, but IMO this speaks to a need to redesign things somehow. I've not thought through the exact details of how.
I think this is the case. Though, I'd still prefer we avoid the extra squeezing and unsqueezing unless there's an explicit solver for the adjoint which involves this issue.
I'll take another look tomorrow just to double check.
Looks good! Regarding the functional-style vs OO-style thing, I think we'll encounter this discussion again when dealing with adaptive solvers, since the steps need to be recorded somehow, and it's probably easiest to do this via an internal state of the solver.
I also plan to write some minor but helpful docstrings for the code later next week.
Thank you, I'm very happy to get this in!
Code
`y`, this means that solvers can now carry along some extra state. In particular this state is initialised outside of `_SdeintAdjointMethod` so that the gradients through it are correct. This is why the code for adjoints has been adjusted in the way that it has, to make that possible.
Documentation
`logqp` option -- my assumption is that we don't need to deprecate this, now that we have a relatively efficient and maintainable way of handling it.
Examples
Tests
`test_againt_sdeint` that compares gradients for `sdeint_adjoint` against `sdeint`.
Note that there are still a few TODOs against the arXiv reference for the new paper. I'm opening this as a PR so you can review it now if you get time, but no time pressure. I'll fill in the reference once the paper is on arXiv, which will be on Monday.