google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Add efficient gdg_jvp term for log-ODE schemes. #20

Closed: lxuechen closed this 4 years ago

lxuechen commented 4 years ago

I've decided to go with the jvp implementation.

I implemented two versions: one with better memory scaling and the other with better time scaling. I'll send more discussion via email.
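A minimal sketch of how such a memory/time trade-off can look, assuming the term is the general-noise contraction sum_{i,j} A_ij (J g_j)(x) g_i(x), where g_j is column j of the diffusion and A is a batch of (hypothetical) Lévy-area coefficients. The names g, W, gdg_jvp_loop and gdg_jvp_batched are illustrative, not torchsde's API. The loop version makes m jvp calls on the original input; the batched version replicates x once per noise column and makes a single jvp call, trading m times the memory for fewer autograd calls.

```python
import torch
from torch.autograd.functional import jvp

torch.manual_seed(0)
batch, d, m = 4, 3, 2
W = torch.randn(d, d * m, dtype=torch.float64)


def g(x):
    # Hypothetical diffusion: maps (batch, d) -> (batch, d, m), one column per Brownian component.
    return torch.tanh(x @ W).reshape(x.shape[0], d, m)


def gdg_jvp_loop(x, A):
    """sum_{i,j} A[:, i, j] * (J g_j)(x) g_i(x) using m jvp calls (lower memory)."""
    G = g(x)  # (batch, d, m)
    out = torch.zeros_like(x)
    for j in range(m):
        # Contract first: v_j = sum_i A_ij g_i, so one jvp per column j suffices.
        v_j = torch.einsum("bdi,bi->bd", G, A[:, :, j])
        _, jv = jvp(lambda z: g(z)[:, :, j], x, v_j)
        out = out + jv
    return out


def gdg_jvp_batched(x, A):
    """Same quantity with a single jvp call, at the cost of m copies of x (faster, more memory)."""
    G = g(x)
    V = torch.einsum("bdi,bij->jbd", G, A)  # V[j] = v_j, shape (m, batch, d)
    X = x.repeat(m, 1, 1)  # one copy of x per noise column: (m, batch, d)

    def f(Zc):
        G_rep = g(Zc.reshape(m * batch, d)).reshape(m, batch, d, m)
        # Copy j only contributes column j: take the diagonal over (copy, column).
        return torch.diagonal(G_rep, dim1=0, dim2=3)  # (batch, d, m)

    _, jv = jvp(f, X, V)
    return jv.sum(dim=-1)


x = torch.randn(batch, d, dtype=torch.float64)
A = torch.randn(batch, m, m, dtype=torch.float64)
print(torch.allclose(gdg_jvp_loop(x, A), gdg_jvp_batched(x, A)))  # True
```

The batched version works because g acts row-wise, so copy j of the input only influences column j of the diagonal output.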

I've only added a simple test. The numerical checks pass. The efficient version is much faster than the loop version on CPU.
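One way to build a numerical check of this kind (an assumption, not necessarily the test added in this PR) is a slow reference that materializes the per-sample Jacobian with torch.autograd.functional.jacobian, validated on a linear diffusion whose Jacobian is known in closed form. The contraction assumed here is sum_{i,j} A_ij (J g_j)(x) g_i(x); B, g and gdg_reference are hypothetical names.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
batch, d, m = 3, 4, 2
B = torch.randn(m, d, d, dtype=torch.float64)


def g(x):
    # Hypothetical linear diffusion: column j of g(x) is B_j x, so (J g_j)(x) = B_j.
    return torch.einsum("jkl,bl->bkj", B, x)  # (batch, d, m)


def gdg_reference(g, x, A):
    """Brute force: build the full Jacobian of g for each sample, then contract."""
    G = g(x)
    out = []
    for b in range(x.shape[0]):
        # J[k, j, l] = d g_{kj} / d x_l for sample b.
        J = jacobian(lambda z: g(z.unsqueeze(0))[0], x[b])
        out.append(torch.einsum("kjl,li,ij->k", J, G[b], A[b]))
    return torch.stack(out)


x = torch.randn(batch, d, dtype=torch.float64)
A = torch.randn(batch, m, m, dtype=torch.float64)
# Closed form for the linear case: sum_{i,j} A_ij B_j B_i x.
expected = torch.einsum("bij,jkl,ilp,bp->bk", A, B, B, x)
print(torch.allclose(gdg_reference(g, x, A), expected))  # True
```

A faster jvp-based implementation can then be compared against gdg_reference on a nonlinear g, where no closed form is available.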

lxuechen commented 4 years ago

Side note: we should squash the commits when we merge this.

patrick-kidger commented 4 years ago

Also, I think gdg_jacobian_contraction could be implemented with a loop/batch trick (my name for your fast Jacobian trick) over vjp rather than jvp, which would be worth doing since PyTorch's autograd is built around reverse mode (vjp).
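A minimal sketch of the vjp variant, assuming gdg_jacobian_contraction computes sum_{i,j} A_ij (J g_j)(x) g_i(x) (the thread does not say exactly what it contracts). torch.autograd.functional.vjp is a single backward pass, whereas torch.autograd.functional.jvp is documented as using the double-backward trick. Here x is replicated once per output entry (k, j), so one reverse-mode call with an all-ones cotangent yields the full Jacobian, which is then contracted. All function names are hypothetical; a small jvp loop is included only as a cross-check.

```python
import torch
from torch.autograd.functional import jvp, vjp

torch.manual_seed(0)
batch, d, m = 4, 3, 2
W = torch.randn(d, d * m, dtype=torch.float64)


def g(x):
    # Hypothetical diffusion (batch, d) -> (batch, d, m), as in a general-noise SDE.
    return torch.tanh(x @ W).reshape(x.shape[0], d, m)


def gdg_vjp_batched(x, A):
    """Full Jacobian of g from ONE reverse-mode call, then contract with A and g."""
    G = g(x)
    V = torch.einsum("bdi,bij->jbd", G, A)  # V[j] = sum_i A_ij g_i, shape (m, batch, d)
    Z = x.repeat(d, m, 1, 1)  # one copy of x per output entry (k, j): (d, m, batch, d)

    def f(Zc):
        G_rep = g(Zc.reshape(-1, d)).reshape(d, m, batch, d, m)
        # Copy (k, j) keeps only entry g_{kj}: out[b, k, j] = G_rep[k, j, b, k, j].
        out = torch.diagonal(G_rep, dim1=0, dim2=3)  # (m, batch, m, d)
        return torch.diagonal(out, dim1=0, dim2=2)  # (batch, d, m)

    # With an all-ones cotangent, copy (k, j) receives exactly row k of J g_j.
    _, jac = vjp(f, Z, torch.ones(batch, d, m, dtype=x.dtype))  # (d, m, batch, d)
    return torch.einsum("kjbl,jbl->bk", jac, V)


def gdg_jvp_loop(x, A):
    # Forward-mode cross-check: one jvp per column j.
    G = g(x)
    return sum(
        jvp(lambda z: g(z)[:, :, j], x, torch.einsum("bdi,bi->bd", G, A[:, :, j]))[1]
        for j in range(m)
    )


x = torch.randn(batch, d, dtype=torch.float64)
A = torch.randn(batch, m, m, dtype=torch.float64)
print(torch.allclose(gdg_vjp_batched(x, A), gdg_jvp_loop(x, A)))  # True
```

The fully batched form holds d * m copies of x; looping over k (d vjp calls, each with m copies) is the lower-memory alternative.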

lxuechen commented 4 years ago


Feel free to give this a go.