patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

SDE - share computation common to drift and diffusion terms #455

Status: Open. vadmbertr opened this issue 2 months ago.

vadmbertr commented 2 months ago

Hi,

Thanks for the great library!

I'm solving an SDE of the form: $$d\mathbf{X}(t) = (\mathbf{u}(t, \mathbf{X}(t)) + (\nabla \cdot \mathbf{K})(t, \mathbf{X}(t)))dt + \mathbf{V}(t, \mathbf{X}(t)) \cdot d\mathbf{W}(t)$$ where $\mathbf{X}$ is the position vector, $\mathbf{u}$ the velocity vector field, $\mathbf{K}$ the diffusivity tensor field, $\mathbf{W}$ a Brownian motion, and $\mathbf{K}=\frac{1}{2} \mathbf{V} \cdot \mathbf{V}^T$.

As you can see, both the drift and diffusion terms depend on the diffusivity tensor field $\mathbf{K}$. In practice, I compute it from the velocity field $\mathbf{u}$. Is there a way to avoid precomputing $\mathbf{K}$ across the entire domain and instead compute it locally (i.e. using only a small neighborhood around $\mathbf{X}(t)$) at each integration step, without having to compute it separately for the drift and for the diffusion term?
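For concreteness, both local quantities the equation needs can be derived from $\mathbf{K}$ alone with automatic differentiation. This is a sketch, not code from the thread: it assumes a hypothetical local_K(t, x) standing in for the neighbourhood computation from $\mathbf{u}$, takes $(\nabla \cdot \mathbf{K})_i = \sum_j \partial K_{ij} / \partial x_j$, and picks the Cholesky factor of $2\mathbf{K}$ as one valid choice of $\mathbf{V}$.

```python
import jax
import jax.numpy as jnp


def local_K(t, x):
    # Hypothetical stand-in for K computed from u near x. It is isotropic,
    # so it is always symmetric positive definite.
    k = 1e-2 * (1.0 + jnp.sum(x**2))
    return k * jnp.eye(x.shape[0])


def div_K(t, x):
    # J[i, j, k] = dK_ij/dx_k, so (div K)_i = sum_j J[i, j, j].
    J = jax.jacfwd(local_K, argnums=1)(t, x)
    return jnp.trace(J, axis1=1, axis2=2)


def V_from_K(K):
    # K = V V^T / 2 means V V^T = 2K; the Cholesky factor of 2K is one such V.
    return jnp.linalg.cholesky(2.0 * K)


x = jnp.array([1.0, 2.0])
K = local_K(0.0, x)
V = V_from_K(K)
print(div_K(0.0, x))  # analytically, div K = 2e-2 * x for this local_K
```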

Vadim

patrick-kidger commented 2 months ago

Yup!

First of all, if you just write things out naively then the compiler should detect the common subexpression and optimise that out. (At least for most solvers. I think a handful of SDE-specific solvers might evaluate the drift and diffusion in two totally different places, and I don't think the compiler is smart enough to precompute and hoist intermediates in that way.)

If you want to be really really sure that you're only doing things in one place, then the appropriate thing to do would be not to use MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion)), but instead to provide a single ControlTerm(drift_and_diffusion, time_and_brownian_motion). (Take a look at the Terms page for discussion on how those work.)
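Written out, the single-term formulation uses one vector field and one control. With $\mathbf{b} = \mathbf{u} + \nabla \cdot \mathbf{K}$ as shorthand for the drift (notation introduced here), the control is $(t, \mathbf{W}(t))$, and contracting its increment with the pair $(\mathbf{b}, \mathbf{V})$ recovers the original equation:

$$d\mathbf{X}(t) = \big(\mathbf{b}(t, \mathbf{X}(t)),\ \mathbf{V}(t, \mathbf{X}(t))\big) \cdot \big(dt,\ d\mathbf{W}(t)\big) = \mathbf{b}(t, \mathbf{X}(t))\,dt + \mathbf{V}(t, \mathbf{X}(t)) \cdot d\mathbf{W}(t)$$

Everything that depends on $\mathbf{K}$ can then be computed once per evaluation of the combined vector field.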

vadmbertr commented 2 months ago

Thanks for the reply!

If I were to go with the ControlTerm approach, is the following (dummy) gist correct?

from diffrax import AbstractPath, ControlTerm, diffeqsolve, Dopri5, SaveAt, VirtualBrownianTree
import jax
import jax.numpy as jnp


class SDEControl(AbstractPath):
    # Combined control (t, W(t)): its increment over [t0, t1] is (t1 - t0, dW).
    t0 = None
    t1 = None
    brownian_motion: VirtualBrownianTree

    def evaluate(self, t0, t1=None, left=True, use_levy=False):
        # ControlTerm always passes both t0 and t1, so t1 is not None here.
        return t1 - t0, self.brownian_motion.evaluate(t0=t0, t1=t1, left=left, use_levy=use_levy)


def vector_field(t, y, args):
    # Dummy dynamics, with separate states for the drift and diffusion parts.
    drift, diffusion = y
    y = drift + diffusion
    # First output multiplies dt; second multiplies dW (shape (1,)).
    return y, jnp.asarray([.5])


t0 = 0
t1 = 1
dt = .1
dt0 = .01

brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-4, shape=(1,), key=jax.random.PRNGKey(0))
sde_control = SDEControl(brownian_motion=brownian_motion)
sde_term = ControlTerm(vector_field, sde_control)

solver = Dopri5()
# linspace: a float-step arange can gain an extra point past t1 through rounding.
saveat = SaveAt(ts=jnp.linspace(t0, t1, round((t1 - t0) / dt) + 1))

y0_drift = 1
y0_diffusion = 0

ys_drift, ys_diffusion = diffeqsolve(
    sde_term, solver, t0=t0, t1=t1, dt0=dt0, y0=(y0_drift, y0_diffusion), saveat=saveat
).ys
ys = ys_drift + ys_diffusion

My feeling is that it will be easier to generalise this simplified example by keeping track of the drift and diffusion terms separately and explicitly summing them when needed.

patrick-kidger commented 2 months ago

This looks mostly correct to me! In particular I think your combined control looks correct.

I think your choice of ys looks a little odd: you wouldn't normally have a separate evolving state for the drift and diffusion. In Diffrax, this is the object equivalent to the $\textbf{X}$ in your equation.

You probably don't want to use Dopri5() to solve an SDE. I believe it will work (and converge to the Stratonovich solution), but it'll be inefficient compared to a lower-order solver.

vadmbertr commented 2 months ago

> I think your choice of ys looks a little odd: you wouldn't normally have a separate evolving state for the drift and diffusion. In Diffrax, this is the object equivalent to the $\textbf{X}$ in your equation.

Yeah, I agree it's odd. I actually found that I can wrap the drift and diffusion terms output by vector_field inside a lineax.PyTreeLinearOperator, which makes it quite easy to extend to 2D. (I was surprised, though, that putting a lineax.DiagonalLinearOperator inside the PyTree does not give the same result as putting its materialized matrix there, which is what I had expected.)

> You probably don't want to use Dopri5() to solve an SDE. I believe it will work (and converge to the Stratonovich solution), but it'll be inefficient compared to a lower-order solver.

About that, the downside of this approach (i.e., using a single ControlTerm) is that it can only be used with ODE solvers, right?

Thanks for the feedback!

patrick-kidger commented 2 months ago

Hmm, it's true that the single-term approach makes things harder to use with the SDE-specific solvers.

In general you can check solver.term_structure to see what's compatible.

Anyway, if nothing else consider using Heun -- this is compatible term-wise, it also converges to the Stratonovich solution, and is asymptotically just as efficient as Tsit5 when used to solve an SDE. :)