google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Heun's method #24

Closed: mtsokol closed this 3 years ago

mtsokol commented 3 years ago

@patrick-kidger Hi!

As suggested, this is the PR against the dev branch, with all the latest comments from the previous PR applied. A few questions:

  1. I pushed diagnostics/stratonovich_diagonal.py and modified problems.py to verify that it's implemented correctly. Do we want it merged, or should I remove it from the PR? (Right now it's more of a draft file.)
  2. SDEIto and SDEStratonovich only add an sde_type parameter. Since the ForwardSDEIto API also suited the Stratonovich case, I changed it to ForwardSDE, which inherits from BaseSDE and takes an sde_type parameter. (It's just a draft; I've only changed it there.)

So will the forward Stratonovich SDE API differ from ForwardSDEIto, or can we unify them this way?
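To make the unification question concrete, here is a minimal plain-Python sketch of the design described in point 2: a single ForwardSDE that carries sde_type, with a solver checking that attribute instead of relying on separate Ito/Stratonovich wrapper classes. The class layout, the `METHOD_SDE_TYPE` table and `check_method` are hypothetical illustrations, not torchsde's actual internals, and lists of floats stand in for tensors.

```python
from typing import Callable, Dict, List

Vector = List[float]
Field = Callable[[float, Vector], Vector]

ITO = "ito"
STRATONOVICH = "stratonovich"


class BaseSDE:
    """Holds the metadata a solver inspects before integrating (hypothetical layout)."""

    def __init__(self, sde_type: str, noise_type: str) -> None:
        if sde_type not in (ITO, STRATONOVICH):
            raise ValueError(f"unknown sde_type: {sde_type!r}")
        self.sde_type = sde_type
        self.noise_type = noise_type


class ForwardSDE(BaseSDE):
    """One wrapper for both calculi: drift/diffusion code is shared, only sde_type differs."""

    def __init__(self, f: Field, g: Field, sde_type: str, noise_type: str = "diagonal") -> None:
        super().__init__(sde_type=sde_type, noise_type=noise_type)
        self._f = f
        self._g = g

    def f(self, t: float, y: Vector) -> Vector:
        return self._f(t, y)

    def g(self, t: float, y: Vector) -> Vector:
        return self._g(t, y)


# Hypothetical table of which interpretation each solver targets.
METHOD_SDE_TYPE: Dict[str, str] = {"euler": ITO, "heun": STRATONOVICH}


def check_method(sde: BaseSDE, method: str) -> None:
    """Reject solver/SDE pairings whose interpretations disagree."""
    expected = METHOD_SDE_TYPE[method]
    if sde.sde_type != expected:
        raise ValueError(f"{method!r} solves {expected} SDEs, got sde_type={sde.sde_type!r}")


# Usage: a Stratonovich geometric Brownian motion paired with Heun's method.
gbm = ForwardSDE(
    f=lambda t, y: [0.5 * yi for yi in y],
    g=lambda t, y: [0.2 * yi for yi in y],
    sde_type=STRATONOVICH,
)
check_method(gbm, "heun")  # passes silently
```

The upside of this shape is that drift and diffusion are written once; the interpretation lives in a single attribute that each solver can validate up front.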

mtsokol commented 3 years ago

Ah, OK, I see that point 2 has just been merged 😄 Let me rebase!

mtsokol commented 3 years ago

@patrick-kidger OK, it's now rebased, and the plots can be generated with python3 -m diagnostics.stratonovich_diagonal.
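As a rough, stand-alone illustration of what a Stratonovich diagnostic can check (this is not the contents of the diagnostics.stratonovich_diagonal module), the sketch below runs the stochastic Heun predictor-corrector on scalar geometric Brownian motion and compares it, along the same Brownian path, with both the exact Stratonovich and the exact Ito solutions. The function name `heun` and all parameter values are assumptions chosen for the example.

```python
import math
import random


def heun(f, g, y0, t0, dt, dWs):
    """Stochastic Heun method for a scalar Stratonovich SDE dY = f dt + g o dW."""
    t, y = t0, y0
    for dW in dWs:
        f0, g0 = f(t, y), g(t, y)
        # Predictor: an Euler step to the end of the interval.
        y_pred = y + f0 * dt + g0 * dW
        # Corrector: average the vector fields at both ends (trapezoidal rule).
        y = y + 0.5 * (f0 + f(t + dt, y_pred)) * dt + 0.5 * (g0 + g(t + dt, y_pred)) * dW
        t += dt
    return y


# Geometric Brownian motion dY = mu*Y dt + sigma*Y o dW along one fixed Brownian path.
mu, sigma, y0, T, n = 0.5, 0.8, 1.0, 1.0, 1000
rng = random.Random(0)
dWs = [rng.gauss(0.0, math.sqrt(T / n)) for _ in range(n)]
W_T = sum(dWs)

approx = heun(lambda t, y: mu * y, lambda t, y: sigma * y, y0, 0.0, T / n, dWs)
strat_exact = y0 * math.exp(mu * T + sigma * W_T)                     # Stratonovich solution
ito_exact = y0 * math.exp((mu - 0.5 * sigma ** 2) * T + sigma * W_T)  # Ito solution, same path

print(f"rel. error vs Stratonovich: {abs(approx - strat_exact) / strat_exact:.2e}")
print(f"rel. error vs Ito:          {abs(approx - ito_exact) / ito_exact:.2e}")
```

The two reference values differ by a factor of exp(sigma^2 T / 2), about 1.38 here, so a Heun implementation that accidentally converged to the Ito solution would show up immediately.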

patrick-kidger commented 3 years ago

Overall this looks like it's come along nicely. Good work.

lxuechen commented 3 years ago

Thanks for redirecting the previous PR here to be on dev, @mtsokol. I was a little busy getting other things in shape.

Also, thanks for taking care of this and making sure things are on the right track, @patrick-kidger.

I'll also give the code a thorough pass and leave some comments. We can merge this when all the comments are resolved.

patrick-kidger commented 3 years ago

Other than the (very minor) outstanding points this LGTM.

patrick-kidger commented 3 years ago

@lxuechen are you happy to merge this?

lxuechen commented 3 years ago

Looking good! Merging now.

Thanks again for being so patient!

patrick-kidger commented 3 years ago

Nice work @mtsokol. From what Chen says above, it sounds like there are some official Google things that might delay setting up CI. It sounds like you want to keep helping out (thank you!), so if you're interested, I can think of a few other projects.

mtsokol commented 3 years ago

Great! Thank you both for assistance!

@patrick-kidger I would love to continue working on this! I'd be grateful if you could give me, say, 2-3 tasks (so that while one is waiting for review I can switch to another). Do we want another Stratonovich solver next? (Looking at DifferentialEquations.jl, they have a Stratonovich Milstein, whose formula differs from the Ito one in one place, and also SRA in that variant.)
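For reference, the "one place" where the Ito and Stratonovich Milstein formulas differ is the correction term: for scalar noise it is 0.5 * g * g' * (dW^2 - dt) in the Ito form and 0.5 * g * g' * dW^2 in the Stratonovich form, with f read as the drift in the respective calculus. The plain-Python sketch below shows this with a user-supplied derivative `dg_dy`; the function names are hypothetical and not torchsde's API.

```python
import math
import random


def milstein_step(y, t, dt, dW, f, g, dg_dy, sde_type):
    """One scalar Milstein step; f is the drift in the calculus named by sde_type."""
    gg = g(t, y) * dg_dy(t, y)
    if sde_type == "ito":
        correction = 0.5 * gg * (dW * dW - dt)
    elif sde_type == "stratonovich":
        # The single difference: the Stratonovich correction has no "- dt".
        correction = 0.5 * gg * dW * dW
    else:
        raise ValueError(f"unknown sde_type: {sde_type!r}")
    return y + f(t, y) * dt + g(t, y) * dW + correction


def solve_gbm(sde_type, mu, sigma, y0, T, dWs):
    """Integrate dY = mu*Y dt + sigma*Y dW (read as Ito or Stratonovich) along a given path."""
    dt = T / len(dWs)
    y, t = y0, 0.0
    for dW in dWs:
        y = milstein_step(y, t, dt, dW,
                          f=lambda t, y: mu * y,
                          g=lambda t, y: sigma * y,
                          dg_dy=lambda t, y: sigma,
                          sde_type=sde_type)
        t += dt
    return y


mu, sigma, y0, T, n = 0.5, 0.8, 1.0, 1.0, 1000
rng = random.Random(1)
dWs = [rng.gauss(0.0, math.sqrt(T / n)) for _ in range(n)]
W_T = sum(dWs)

ito = solve_gbm("ito", mu, sigma, y0, T, dWs)
strat = solve_gbm("stratonovich", mu, sigma, y0, T, dWs)
# Closed-form solutions on the same path, for comparison.
ito_exact = y0 * math.exp((mu - 0.5 * sigma ** 2) * T + sigma * W_T)
strat_exact = y0 * math.exp(mu * T + sigma * W_T)
```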

My BSc is in CS rather than Math, but if you have a task that requires such knowledge and can point me to sources I can learn from, I can give it a shot too.

patrick-kidger commented 3 years ago

@mtsokol Excellent; glad to have you on board.

So there are several new schemes that would be nice to add.

For now I'm inclined to suggest keeping Ito and Strat solvers separate even if there's code duplication, but I think the derivative/derivative-free Milstein methods could probably go together, and the choice between them decided by a flag in options.
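A sketch of the "one method, flag in options" idea for the Ito case: the derivative-based variant uses a user-supplied dg/dy, while the derivative-free variant estimates g * g' from a finite difference at the support value y + f dt + g sqrt(dt), following the standard Kloeden-Platen construction. The `derivative_free` key and the function names are assumptions for illustration, not an existing torchsde option.

```python
import math
import random


def milstein_ito_step(y, t, dt, dW, f, g, options, dg_dy=None):
    """Scalar Ito Milstein step; options["derivative_free"] (hypothetical key) selects the variant."""
    f0, g0 = f(t, y), g(t, y)
    if options.get("derivative_free", False):
        # Support value: g(y_support) - g(y) is approximately g' * g * sqrt(dt).
        sqrt_dt = math.sqrt(dt)
        y_support = y + f0 * dt + g0 * sqrt_dt
        gg = (g(t, y_support) - g0) / sqrt_dt
    else:
        # Derivative-based variant: the caller must supply dg/dy.
        gg = g0 * dg_dy(t, y)
    return y + f0 * dt + g0 * dW + 0.5 * gg * (dW * dW - dt)


def solve(options, mu=0.5, sigma=0.8, y0=1.0, T=1.0, n=1000, seed=2):
    """Ito GBM along a fixed Brownian path; returns (numerical, exact) terminal values."""
    rng = random.Random(seed)
    dt = T / n
    y, t, W = y0, 0.0, 0.0
    for _ in range(n):
        dW = rng.gauss(0.0, math.sqrt(dt))
        y = milstein_ito_step(y, t, dt, dW,
                              f=lambda t, y: mu * y,
                              g=lambda t, y: sigma * y,
                              options=options,
                              dg_dy=lambda t, y: sigma)
        t += dt
        W += dW
    return y, y0 * math.exp((mu - 0.5 * sigma ** 2) * T + sigma * W)


# Same seed, so both variants see the same Brownian path and share one exact value.
with_derivative, exact = solve({"derivative_free": False})
derivative_free, _ = solve({"derivative_free": True})
```

Keeping both variants behind one flag means the shared drift/diffusion/increment logic is written once, and only the estimate of g * g' changes.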

Other than that:

More broadly speaking, if there's any particular functionality you see in some other package (e.g. DifferentialEquations.jl, but in anything else as well) that you'd like to replicate, then that would also be of interest.

If there are any particular problems you're interested in solving (and perhaps find one if there isn't!), then seeing how torchsde handles that use case would be interesting. New functionality tends to get added primarily as a result of needing it to solve some problem. That's literally how I got involved in the development of both torchdiffeq and torchsde, oops.

That's a decently long list of things so by all means pick and choose what seems most attractive to you! Very open to any suggestions on your part if you want to do something specific.

mtsokol commented 3 years ago

Hi again!

(Over the next few days I'll spend some time on the next solvers, since there's a separate issue listing them as bullet points.)

I saw that you merged support for the Stratonovich adjoint two days ago. I tried to follow up on that PR and saw that you set out a TODO list there afterwards; if there's any spare task on it that I could try, let me know.

Also, regarding the benchmarking you mentioned: the paper mentions in a few places the backprop-through-the-solver method described in [19] ("Smoking adjoints: Fast Monte Carlo greeks") and compares it with the stochastic adjoint. Have you run exactly the model described in [19] to compare both methods in that context as well? (If not, do you think an attempt to do so is worthwhile?)
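For context on what "backprop through the solver" means for a greek, here is a small sketch. It is a toy, not the model from [19]: the delta of a European call under Euler-discretised geometric Brownian motion is obtained by recording each step's Jacobian in a forward pass and chaining them in reverse, then compared with the Black-Scholes delta. The parameter values and function names are assumptions; an autograd framework would perform the backward pass automatically, whereas a stochastic adjoint would instead solve a backward-in-time SDE.

```python
import math
import random


def pathwise_delta(s0, k, r, sigma, T, n_steps, n_paths, seed=0):
    """Monte Carlo call delta by reverse-mode differentiation through an Euler GBM solver."""
    rng = random.Random(seed)
    dt = T / n_steps
    disc = math.exp(-r * T)
    total = 0.0
    for _ in range(n_paths):
        # Forward pass: simulate and record each step's Jacobian dS_{n+1}/dS_n.
        s, jacobians = s0, []
        for _ in range(n_steps):
            dW = rng.gauss(0.0, math.sqrt(dt))
            step_factor = 1.0 + r * dt + sigma * dW
            jacobians.append(step_factor)
            s *= step_factor
        # Backward pass: seed with d(discounted payoff)/dS_T, chain back to S_0.
        adjoint = disc * (1.0 if s > k else 0.0)
        for jac in reversed(jacobians):
            adjoint *= jac
        total += adjoint
    return total / n_paths


def black_scholes_delta(s0, k, r, sigma, T):
    """Closed-form delta N(d1) of a European call, used as the reference."""
    d1 = (math.log(s0 / k) + (r + 0.5 * sigma ** 2) * T) / (sigma * math.sqrt(T))
    return 0.5 * (1.0 + math.erf(d1 / math.sqrt(2.0)))


estimate = pathwise_delta(s0=100.0, k=100.0, r=0.05, sigma=0.2, T=1.0, n_steps=50, n_paths=20000)
reference = black_scholes_delta(s0=100.0, k=100.0, r=0.05, sigma=0.2, T=1.0)
```

Storing every per-step Jacobian is what makes this approach memory-hungry for long paths, which is the cost the stochastic adjoint avoids.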

lxuechen commented 3 years ago

Hi!

> I saw that you merged support for the Stratonovich adjoint two days ago. I tried to follow up on that PR and saw that you set out a TODO list there afterwards; if there's any spare task on it that I could try, let me know.

For sure! I think things now depend on #37, but I'm sure there'd be some interesting further developments going forward.

> Have you run exactly the model described in [19] to compare both methods in that context as well? (If not, do you think an attempt to do so is worthwhile?)

At the time of writing our paper, we didn't run the exact problems presented in [19], and I agree it'd be interesting to find out what the efficiency comparison is like for these problems as an academic endeavour.

Aside: if you're interested in a technical challenge on the modeling side, here's something fairly straightforward. It would be interesting to see whether we could replicate the qualitative results of the latent_sde.py example without using the logqp interface. The central idea is that we should be able to do this just by augmenting the state of the forward dynamics with an extra dimension whose drift is built from u = (f - h) / g, namely 0.5 * |u|^2 (the integrand of the KL term), and whose diffusion is 0. Then, to compute the loss at the end, we only need to aggregate the reconstruction loss, which depends on the first d states of the output, and the KL loss, which should depend on the (d+1)-th state of the new augmented system.
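A plain-Python sketch of the augmentation, assuming diagonal noise and using lists instead of the tensors in examples/latent_sde.py: the extra coordinate's drift is 0.5 * |u|^2 with u = (f - h) / g, which by Girsanov's theorem is the integrand of the KL divergence between the posterior and prior path measures, and its diffusion is 0. `augment`, `euler_maruyama` and the constant toy fields are hypothetical.

```python
import math
import random


def augment(f, h, g):
    """Build drift/diffusion for the state [y_1, ..., y_d, kl] (diagonal noise, lists of floats)."""

    def f_aug(t, y_aug):
        y = y_aug[:-1]
        f_val, h_val, g_val = f(t, y), h(t, y), g(t, y)
        u = [(fi - hi) / gi for fi, hi, gi in zip(f_val, h_val, g_val)]
        # Extra coordinate accumulates 0.5 * |u|^2 dt, the KL integrand.
        return f_val + [0.5 * sum(ui * ui for ui in u)]

    def g_aug(t, y_aug):
        # Zero diffusion on the extra coordinate: it is a pathwise integral, not noisy.
        return g(t, y_aug[:-1]) + [0.0]

    return f_aug, g_aug


def euler_maruyama(f, g, y0, T, n_steps, rng):
    """Plain Euler-Maruyama for diagonal noise; one Gaussian increment per coordinate."""
    dt = T / n_steps
    y, t = list(y0), 0.0
    for _ in range(n_steps):
        f_val, g_val = f(t, y), g(t, y)
        y = [yi + fi * dt + gi * rng.gauss(0.0, math.sqrt(dt))
             for yi, fi, gi in zip(y, f_val, g_val)]
        t += dt
    return y


def posterior_drift(t, y):  # toy posterior drift f
    return [1.0, -0.5]


def prior_drift(t, y):  # toy prior drift h
    return [0.0, 0.0]


def diffusion(t, y):  # shared diagonal diffusion g
    return [0.5, 0.5]


f_aug, g_aug = augment(posterior_drift, prior_drift, diffusion)
y_final = euler_maruyama(f_aug, g_aug, y0=[0.0, 0.0, 0.0], T=2.0, n_steps=200, rng=random.Random(0))
latent, kl = y_final[:-1], y_final[-1]  # reconstruction uses latent; KL loss uses kl
```

With these constant fields u = (2, -1), so the extra coordinate grows at rate 2.5 and ends at 5.0 when T = 2. In the real model f, h and g depend on the state, so that coordinate becomes a per-sample Monte Carlo estimate of the KL term.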

mtsokol commented 3 years ago

@lxuechen Thanks for the advice! I've started a draft of the [19] benchmark, and if you don't mind, I'll have a few questions next week to make sure I understand those approaches correctly.

Regarding latent_sde.py: sure, I can try it. One question: on the latest version of the dev branch, the command python3 -m examples.latent_sde --adaptive --rtol 1e-2 --atol 1e-3 --train-dir ~/Desktop fails with:

...
logqp = [torch.stack(logqp_i, dim=0) for logqp_i in logqp]
RuntimeError: stack expects each tensor to be equal size, but got [512] at entry 0 and [] at entry 2