Closed lxuechen closed 3 years ago
So the main issue with removing the re-evaluation of `g` in those methods by passing an already computed `g` is that for `ForwardSDE` we need it evaluated with `t` and for `AdjointSDE` with `-t`, right?
I don't think so. (The switching of signs is handled by `ForwardSDE`/`AdjointSDE`, not the solver.)
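To illustrate the point about sign handling, here is a minimal sketch in which the solver only ever calls `sde.f(t, y)` and the wrapper decides how time is interpreted. The names `ForwardWrapper`, `ReverseTimeWrapper` and `euler` are hypothetical and are not torchsde's classes. Only the drift is shown, to keep it short, and the real adjoint carries extra augmented state that is omitted here.

```python
class ForwardWrapper:
    """Hypothetical forward wrapper: passes time through unchanged."""

    def __init__(self, drift):
        self._drift = drift

    def f(self, t, y):
        return self._drift(t, y)


class ReverseTimeWrapper:
    """Hypothetical reverse-time wrapper: the solver runs in s = -t,
    so the wrapper evaluates the drift at -s and negates the result."""

    def __init__(self, drift):
        self._drift = drift

    def f(self, t, y):
        return -self._drift(-t, y)


def euler(sde, y, t0, t1, n_steps):
    """Plain Euler loop; it knows nothing about time reversal."""
    dt = (t1 - t0) / n_steps
    for k in range(n_steps):
        y = y + sde.f(t0 + k * dt, y) * dt
    return y


def drift(t, y):
    return t  # dy/dt = t, so y(1) is roughly 0.5 when y(0) = 0


y1 = euler(ForwardWrapper(drift), 0.0, 0.0, 1.0, 1000)
# Integrate back from t = 1 to t = 0 by running the same solver from s = -1 to s = 0.
y0_recovered = euler(ReverseTimeWrapper(drift), y1, -1.0, 0.0, 1000)
```

The same `euler` function is used in both directions, which is the sense in which the solver does not need to know about the sign of `t`.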
The issue is that `AdjointSDE` doesn't define a `g` at all! It only defines `g_prod`. It turns out for the `AdjointSDE` that `g_prod` is efficient to compute whilst `g` isn't.
Thus the most efficient way of doing things for the forward pass would be to evaluate `g` once, and then pass it in later, whilst for the backward pass the most efficient thing is just to evaluate `g_prod` multiple times. We don't yet do this optimisation on the forward pass, but we do the optimised thing on the backward pass.
I think the only solver that is currently a bit inefficient in this way is grad-free Milstein. (Correct me if I've missed one.) Thus I think the neatest way to fix this would be to add an optional `g` argument to `g_prod`, which grad-free Milstein uses. Then `ForwardSDE.g_prod` checks if that argument was passed and evaluates `g` if not, whilst `AdjointSDE.g_prod` does an assert that `g is None`, as grad-free Milstein doesn't work for the adjoint SDE anyway.
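A rough sketch of what that proposal could look like, using plain Python lists rather than tensors so that it runs standalone. The class bodies, the `g_evals` counter and the `diffusion` function are illustrative assumptions and are not torchsde's actual implementation. Only the optional `g` argument and the `g is None` assert come from the suggestion above.

```python
class ForwardSDE:
    """Toy stand-in for the forward wrapper (assumed structure)."""

    def __init__(self, diffusion):
        self._diffusion = diffusion  # callable (t, y) -> matrix as a list of rows
        self.g_evals = 0             # counter added only for this illustration

    def g(self, t, y):
        self.g_evals += 1
        return self._diffusion(t, y)

    def g_prod(self, t, y, v, g=None):
        # Reuse a precomputed g when the caller supplies one; otherwise evaluate it.
        if g is None:
            g = self.g(t, y)
        return [sum(g_ij * v_j for g_ij, v_j in zip(row, v)) for row in g]


class AdjointSDE:
    """Toy stand-in for the adjoint wrapper: defines g_prod but no g."""

    def __init__(self, g_prod_fn):
        self._g_prod_fn = g_prod_fn  # callable (t, y, v) -> vector

    def g_prod(self, t, y, v, g=None):
        # There is no g to reuse on the backward pass, so a caller passing one is a bug.
        assert g is None, "AdjointSDE does not define g"
        return self._g_prod_fn(t, y, v)


def diffusion(t, y):
    # Hypothetical diagonal diffusion, chosen only to have concrete numbers.
    return [[y[0], 0.0], [0.0, 2.0 * y[1]]]


sde = ForwardSDE(diffusion)
t, y = 0.0, [1.0, 3.0]
g = sde.g(t, y)                          # evaluated once
a = sde.g_prod(t, y, [1.0, 1.0], g=g)    # reuses g
b = sde.g_prod(t, y, [0.5, -1.0], g=g)   # reuses g again
```

A solver such as grad-free Milstein, which needs several products against the same `g`, would take the `g=g` path, while solvers that only need one product can keep calling `g_prod(t, y, v)` unchanged.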
`ForwardSDE` and `AdjointSDE` for `g_prod` and `gdg_prod`/`gdg_jvp`/`gdg_jvp`.
Dictionary lookup -> if/else statements after profiling. (I know you put this in because I commented on it, but it's not important enough that I'm fussed. - Patrick)