google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

Further Python-end optimizations #23

Closed · lxuechen closed this 3 years ago

lxuechen commented 4 years ago
mtsokol commented 4 years ago

So the main obstacle to removing the re-evaluation of g in those methods by passing in an already-computed g is that ForwardSDE needs it evaluated at t, while AdjointSDE needs it at -t, right?

patrick-kidger commented 4 years ago

I don't think so. (The switching of signs is handled by ForwardSDE/AdjointSDE, not the solver.)

The issue is that AdjointSDE doesn't define a g at all! It only defines g_prod. It turns out that for the AdjointSDE, g_prod is efficient to compute whilst g isn't.

Thus on the forward pass the most efficient approach would be to evaluate g once and then pass it in later, whilst on the backward pass the most efficient thing is just to evaluate g_prod multiple times. We don't yet do this optimisation on the forward pass, but we do do the optimised thing on the backward pass.
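For intuition, here's a minimal sketch of that cost asymmetry (toy diffusion, diagonal noise; none of this is torchsde's actual code). Loosely, the adjoint's diffusion-times-noise boils down to a vector-Jacobian product of the forward diffusion, which autograd gives in one backward pass, whereas materialising the adjoint's g would require the full Jacobian:

```python
import torch

# Toy diffusion standing in for the SDE's g, with diagonal noise so that
# g(t, y) has the same shape as y. All names here are illustrative.
def g(t, y):
    return torch.tanh(y) * t

t = torch.tensor(0.3)
y = torch.randn(8, requires_grad=True)
a = torch.randn(8)  # adjoint variable
v = torch.randn(8)  # Brownian increment

g_eval = g(t, y)

# g_prod-style quantity: a vector-Jacobian product against the noise.
# One backward pass; the Jacobian dg/dy is never formed.
vjp, = torch.autograd.grad(g_eval, y, grad_outputs=a * v, retain_graph=True)

# Materialising the adjoint's "g" itself would instead need the full
# Jacobian dg/dy: one backward pass per output dimension.
rows = [torch.autograd.grad(g_eval, y, grad_outputs=e, retain_graph=True)[0]
        for e in torch.eye(8)]
jac = torch.stack(rows)  # shape (8, 8): d times the work of the single vjp
```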

I think the only solver that is currently a bit inefficient in this way is grad-free Milstein. (Correct me if I've missed one.) Thus I think the neatest way to fix this would be to add an optional g argument to g_prod, which grad-free Milstein uses. Then ForwardSDE.g_prod checks whether that argument was passed and evaluates g if not, whilst AdjointSDE.g_prod asserts that g is None, as grad-free Milstein doesn't work for the adjoint SDE anyway.
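A hypothetical sketch of that change (class and method names mirror the discussion, not necessarily the real torchsde internals; diagonal noise for brevity, and _vjp_g_prod is a made-up helper standing in for the vjp computation sketched above):

```python
class ForwardSDE:
    def __init__(self, sde):
        self._sde = sde

    def g(self, t, y):
        return self._sde.g(t, y)

    def g_prod(self, t, y, v, g=None):
        # Reuse a diffusion the solver already evaluated (grad-free
        # Milstein needs g anyway for its extra stage); fall back to a
        # fresh evaluation otherwise.
        if g is None:
            g = self._sde.g(t, y)
        return g * v


class AdjointSDE:
    def g_prod(self, t, y, v, g=None):
        # The adjoint never materialises g, so nothing can be passed in;
        # grad-free Milstein isn't used with the adjoint SDE anyway.
        assert g is None
        return self._vjp_g_prod(t, y, v)  # hypothetical: one vjp, as above
```

This keeps the solver interface unchanged for every other method: solvers that never evaluate g simply omit the new keyword argument.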