google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Lots of code reuse between solvers, and between adjoint SDEs #13

Closed: patrick-kidger closed this issue 3 years ago

patrick-kidger commented 3 years ago

To be clear, not raising this because I feel like being nitpicky - read on!

Other than very minor differences (e.g. order), the code for each Euler--Maruyama/Milstein/SRK method is essentially the same across noise types. Instead one can have e.g. a single Euler that accepts order as an argument, and four different wrappers that set that argument.
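A minimal sketch of that suggestion, written in plain Python (lists instead of tensors) so it runs without torch. The class and wrapper names are hypothetical and not torchsde's API. The order rule used here (strong order 1.0 for additive noise, 0.5 otherwise) is the standard Euler--Maruyama result, and the only other per-noise-type difference is how `g` acts on dW:

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[Vector]


def diag_prod(g: Vector, dw: Vector) -> Vector:
    # Diagonal noise: g is a length-d vector that acts elementwise on dW.
    return [gi * dwi for gi, dwi in zip(g, dw)]


def matrix_prod(g: Matrix, dw: Vector) -> Vector:
    # Additive/scalar/general noise: g is a d x m matrix, so g dW is a mat-vec product.
    return [sum(gij * dwj for gij, dwj in zip(row, dw)) for row in g]


class EulerMaruyama:
    """A single Euler--Maruyama solver parameterised by noise type (hypothetical)."""

    def __init__(self, f: Callable, g: Callable, noise_type: str):
        self.f, self.g, self.noise_type = f, g, noise_type
        # The only per-noise-type differences: how g acts on dW, and the strong order.
        self.g_prod = diag_prod if noise_type == "diagonal" else matrix_prod
        self.strong_order = 1.0 if noise_type == "additive" else 0.5

    def step(self, t: float, y: Vector, dt: float, dw: Vector) -> Vector:
        # y_{n+1} = y_n + f(t, y_n) dt + g(t, y_n) dW
        drift = self.f(t, y)
        noise = self.g_prod(self.g(t, y), dw)
        return [yi + fi * dt + ni for yi, fi, ni in zip(y, drift, noise)]


# Four thin wrappers that only fix the noise_type argument.
class EulerDiagonal(EulerMaruyama):
    def __init__(self, f: Callable, g: Callable):
        super().__init__(f, g, "diagonal")


class EulerAdditive(EulerMaruyama):
    def __init__(self, f: Callable, g: Callable):
        super().__init__(f, g, "additive")


class EulerScalar(EulerMaruyama):
    def __init__(self, f: Callable, g: Callable):
        super().__init__(f, g, "scalar")


class EulerGeneral(EulerMaruyama):
    def __init__(self, f: Callable, g: Callable):
        super().__init__(f, g, "general")
```

With this layout a custom solver needs only one class; the wrappers (or `functools.partial`) are pure convenience.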

Similarly, the adjoint SDE definitions have lots of duplicated code; it'd be easier to write down a single adjoint SDE for the general noise case. (And any unnecessary computation can be masked out with if statements checking the noise type.)

I think the latter in particular offers an easy opportunity to unify Ito and Stratonovich: it'd just be another if statement checking if the correction term should be applied.
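As a rough illustration of both points, here is a sketch in which a single function handles every noise type and the Ito/Stratonovich distinction is one `if`. It uses the standard conversion: the Ito SDE dY = f dt + g dW has the same solutions as the Stratonovich SDE with drift f_i - 1/2 Σ_{j,k} g_kj ∂g_ij/∂y_k. The function name and the `noise_type`/`target_sde_type` strings are hypothetical, and derivatives of g are taken by central finite differences only to keep the sketch dependency-free (a torch version would use autograd). This is the forward SDE, not the adjoint:

```python
from typing import Callable, List

Vector = List[float]


def drift_for_solver(f: Callable, g: Callable, t: float, y: Vector,
                     noise_type: str, target_sde_type: str,
                     eps: float = 1e-6) -> Vector:
    """Drift to use for the Ito SDE dY = f dt + g dW under `target_sde_type`.

    Stratonovich drift: f_i - 1/2 * sum_{j,k} g_kj * dg_ij/dy_k, with the
    derivatives of g estimated by central differences (sketch-only assumption).
    """
    drift = f(t, y)
    # Ito solvers use f unchanged; with additive noise g does not depend on y,
    # so the correction is identically zero and the work can be skipped.
    if target_sde_type == "ito" or noise_type == "additive":
        return drift
    if noise_type == "diagonal":
        # g returns a length-d vector and g_i depends only on y_i, so perturbing
        # every coordinate at once yields all dg_i/dy_i from two evaluations.
        gy = g(t, y)
        gp = g(t, [yi + eps for yi in y])
        gm = g(t, [yi - eps for yi in y])
        correction = [gi * (gpi - gmi) / (2 * eps)
                      for gi, gpi, gmi in zip(gy, gp, gm)]
    else:
        # Scalar/general noise: g is d x m and the full d x m x d Jacobian is
        # needed, costing 2d extra evaluations of g.
        gy = g(t, y)
        d, m = len(gy), len(gy[0])
        correction = [0.0] * d
        for k in range(d):
            yp = list(y)
            yp[k] += eps
            ym = list(y)
            ym[k] -= eps
            gp, gm = g(t, yp), g(t, ym)
            for i in range(d):
                for j in range(m):
                    correction[i] += gy[k][j] * (gp[i][j] - gm[i][j]) / (2 * eps)
    return [fi - 0.5 * ci for fi, ci in zip(drift, correction)]
```

For geometric Brownian motion (f = μy, g = σy) this recovers the familiar Stratonovich drift (μ - σ²/2)y. The general branch is the expensive one: beyond the extra evaluations of g it accumulates an O(d²m) sum, whereas diagonal noise needs only two extra evaluations.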

As for why I'm raising this - I'm interested in:

The first is aided by tidying up the existing code for solvers (I don't want to define MyCustomSolverAdditive, MyCustomSolverGeneral,...), the second is made possible by improving the adjoint SDE, and the third is a stretch goal that'll probably be easier if it can just be implemented with if sde_type= corrections in the right places.

Tidying up the solvers at least is probably something I can offer a PR on, but I wanted to check if this fit your general vision, or if you think there's anything that might go wrong, as it pretty much involves removing the whole methods directory structure.

lxuechen commented 3 years ago

To be clear, not raising this because I feel like being nitpicky - read on!

No problem! Please do and I'm happy to help in any way that I can!

Other than very minor differences (e.g. order), the code for each Euler--Maruyama/Milstein/SRK method is essentially the same across noise types. Instead one can have e.g. a single Euler that accepts order as an argument, and four different wrappers that set that argument.

I agree. This part of the codebase has an element of being very research-y, and unfortunately, the refactoring I did last minute wasn't super complete. The euler.py in diagonal is kinda redundant, and at the very least it should follow the same dependency-injection-style construction as the other classes.

I'm for any changes to the Euler methods that don't break things.

Similarly, the adjoint SDE definitions have lots of duplicated code; it'd be easier to write down a single adjoint SDE for the general noise case. (And any unnecessary computation can be masked out with if statements checking the noise type.)

This is slightly trickier. The form of g_prod for the adjoint of the general case is clear to me. What's not super clear to me is the form of gdg_prod. This function would also need to take in estimates of the Levy area, which the current codebase doesn't yet support and which is slightly more involved to make work. Because of this, I'm not totally in favor of the prospect that "any unnecessary computation can be masked out with if statements checking the noise type". Masking out the computation would probably be a hard task (at the very least, it's a hard task for me atm), and on top of that, the function signature of gdg_prod would probably differ between the general and diagonal cases.
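A hedged sketch of why the two signatures diverge, using the standard Ito Milstein correction term Σ_{j1,j2} (L^{j1} g_{·,j2}) I_{j1 j2}, where L^{j1} = Σ_k g_{k j1} ∂/∂y_k and the Ito double integral is I_{j1 j2} = (ΔW_{j1} ΔW_{j2} - h δ_{j1 j2})/2 + A_{j1 j2}, with A the antisymmetric Levy area. The function names are hypothetical (this is not torchsde's `gdg_prod`), derivatives are passed in explicitly, and `levy_area` is assumed to be supplied by the caller because the codebase doesn't yet estimate it:

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def milstein_term_diagonal(g: Vector, dg: Vector, dw: Vector, h: float) -> Vector:
    """Diagonal noise: the sum collapses to g_i * dg_i/dy_i * (dW_i**2 - h) / 2.

    dg[i] is dg_i/dy_i. No Levy area is needed.
    """
    return [gi * dgi * (dwi * dwi - h) / 2 for gi, dgi, dwi in zip(g, dg, dw)]


def milstein_term_general(g: Matrix, dg: List[Matrix], dw: Vector, h: float,
                          levy_area: Matrix) -> Vector:
    """General noise: sum_{j1,j2} (L^{j1} g_{:, j2}) * I_{j1 j2}.

    g is d x m, dg[i][j][k] is dg_ij/dy_k, levy_area[j1][j2] is A_{j1 j2}.
    """
    d, m = len(g), len(g[0])
    out = [0.0] * d
    for j1 in range(m):
        for j2 in range(m):
            # Ito double integral of dW_{j1} dW_{j2} over the step.
            diag = h if j1 == j2 else 0.0
            double_int = (dw[j1] * dw[j2] - diag) / 2 + levy_area[j1][j2]
            for i in range(d):
                # (L^{j1} g_{i j2}) = sum_k g_{k j1} * dg_{i j2}/dy_k
                lg = sum(g[k][j1] * dg[i][j2][k] for k in range(d))
                out[i] += lg * double_int
    return out
```

For diagonal noise every cross term has L^{j1} g_{·,j2} = 0, so the general version reduces to the diagonal one and the Levy area drops out; for non-commutative noise the Levy area changes the result, which is why it has to appear in the general signature.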

as it pretty much involves removing the whole methods directory structure.

At this point, I'm still in favor of keeping this directory. I think it should be clear that some solvers are specific to certain classes of SDEs, and we probably shouldn't preclude such solvers from being added. As a compromise, it might make sense to "merge" the methods and methods_strat folders, since certain solvers will be shared by both SDE types, e.g. solvers for additive noise SDEs (with additive noise, the Ito and Stratonovich interpretations coincide). I also think it might be worthwhile to keep the structure in the methods folder that separates different SDE classes, with solvers that may be shared by different SDEs placed directly under methods rather than in a subfolder.
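One way to mirror that compromise in code, and not only in folders, is a small registry in which each solver declares the SDE types and noise types it supports: shared solvers declare several, specific ones declare one. Everything here (the registry, solver names, and folder comments) is hypothetical and not torchsde's actual structure:

```python
from typing import Callable, Dict, FrozenSet, List, Set, Tuple

# Hypothetical registry: solver name -> (supported SDE types, supported noise types, class).
SOLVERS: Dict[str, Tuple[FrozenSet[str], FrozenSet[str], type]] = {}
ALL_NOISE = {"diagonal", "additive", "scalar", "general"}


def register(name: str, sde_types: Set[str], noise_types: Set[str]) -> Callable[[type], type]:
    def decorator(cls: type) -> type:
        SOLVERS[name] = (frozenset(sde_types), frozenset(noise_types), cls)
        return cls
    return decorator


@register("euler", {"ito"}, ALL_NOISE)
class Euler:  # general-purpose: would live directly under methods/
    pass


@register("midpoint", {"stratonovich"}, ALL_NOISE)
class Midpoint:
    pass


# With additive noise the Ito and Stratonovich SDEs coincide, so one solver serves both.
@register("additive_shared", {"ito", "stratonovich"}, {"additive"})
class AdditiveShared:
    pass


# Specific to one SDE class: would live in a subfolder such as methods/diagonal/.
@register("milstein_diagonal", {"ito"}, {"diagonal"})
class MilsteinDiagonal:
    pass


def available(sde_type: str, noise_type: str) -> List[str]:
    # Names of all registered solvers that support this (sde_type, noise_type) pair.
    return sorted(name for name, (sdes, noises, _) in SOLVERS.items()
                  if sde_type in sdes and noise_type in noises)
```

The folder layout can then stay as documentation of intent, while dispatch never has to special-case where a solver lives.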

Solving the adjoint SDE with general noise. (And I'm happy to do so with Euler--Maruyama, knowing the strong order is 0.5. It should still work, right?)

A side comment: when we initially tried this, it was quite hard to get the adjoint gradients to match the finite-difference ones, and a lot of steps are typically needed. I'd love to discuss more if you have a particular SDE in mind.

The overall message I have is

What do you think?

lxuechen commented 3 years ago

An additional note: it might not be straightforward to come up with an efficient implementation of the Stratonovich correction term for general noise.

patrick-kidger commented 3 years ago

If you see this first, check your email :)

patrick-kidger commented 3 years ago

Closing this as we now have more specific plans.

mtsokol commented 3 years ago

@patrick-kidger Hi! I've recently started getting familiar with this repository, and while implementing Heun's method in #4 I also noticed some parts where logic could be unified to avoid duplication, especially now that support for Stratonovich is being introduced. I'd love to continue contributing. I've tried to follow your discussion in #15 regarding this general refactor; if you have a spare task that I could try, please let me know.

patrick-kidger commented 3 years ago

@mtsokol Hello there. So I'm not completely sure that I'm the one you should be asking - this is really @lxuechen's repository, not mine!

That said - we're currently in the process of rewriting large portions of the internals - see the dev branch. In particular, we should now have support for Stratonovich SDEs via the midpoint method, so you could look at structuring your work on Heun's method in the same manner, and then submit a PR there?

Glancing at what you've done so far, I see you've done a diagonal-only version of Heun's method. I think it should be applicable to any noise type, so you could also look at generalising it (again, cf. how midpoint is implemented).
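A rough sketch of that generalisation: Heun's method only touches the noise through the product g(t, y) dW, so writing the step in terms of a single `g_prod` makes it noise-type agnostic, the same way midpoint can be. This is plain Python with hypothetical function names, not the dev branch's code; it assumes `g(t, y)` returns a d x m matrix (diagonal noise being the special case of a diagonal matrix):

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[Vector]


def g_prod(g: Matrix, dw: Vector) -> Vector:
    # g(t, y) dW for a d x m diffusion matrix; covers every noise type.
    return [sum(gij * dwj for gij, dwj in zip(row, dw)) for row in g]


def heun_step(f: Callable, g: Callable, t: float, y: Vector, h: float, dw: Vector) -> Vector:
    """Stochastic Heun (predictor-corrector) step for a Stratonovich SDE."""
    f0, n0 = f(t, y), g_prod(g(t, y), dw)
    # Predictor: a plain Euler step.
    y_pred = [yi + fi * h + ni for yi, fi, ni in zip(y, f0, n0)]
    f1, n1 = f(t + h, y_pred), g_prod(g(t + h, y_pred), dw)
    # Corrector: average drift and diffusion at both ends of the step.
    return [yi + 0.5 * (a + b) * h + 0.5 * (c + e)
            for yi, a, b, c, e in zip(y, f0, f1, n0, n1)]


def midpoint_step(f: Callable, g: Callable, t: float, y: Vector, h: float, dw: Vector) -> Vector:
    """Midpoint step for a Stratonovich SDE, written the same way for comparison."""
    f0, n0 = f(t, y), g_prod(g(t, y), dw)
    y_mid = [yi + 0.5 * fi * h + 0.5 * ni for yi, fi, ni in zip(y, f0, n0)]
    f1, n1 = f(t + 0.5 * h, y_mid), g_prod(g(t + 0.5 * h, y_mid), dw)
    return [yi + fi * h + ni for yi, fi, ni in zip(y, f1, n1)]
```

For dY = Y ∘ dW with zero drift, Y(0) = 1 and a single increment ΔW = 0.1, both steps give 1.105 (up to rounding), matching the expansion 1 + ΔW + ΔW²/2 of the exact Stratonovich solution exp(ΔW).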