patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

IID Brownian motion? #228

Open jajcayn opened 1 year ago

jajcayn commented 1 year ago

Hey there. First off, thanks a lot for this library! I wondered whether it is possible to solve a 2-dimensional SDE with IID noise. As a simple example, consider the Ornstein-Uhlenbeck process (I need it to drive another ODE, so I want to use sol_ou.evaluate(t) in the later vector field).

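import diffrax as dfx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np

t0, t1 = 0.0, 20.0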
DT = 0.001
dim = 2

def drift_ou(t, y, args):
    return (args["mu"] - y) / args["tau"]

def diffusion_ou(t, y, args):
    return jnp.ones((dim))*args["sigma"]

brownian_motion = dfx.VirtualBrownianTree(
    t0,
    t1,
    tol=1e-3,
    shape=jax.ShapeDtypeStruct((dim,), np.float64),
    key=jr.PRNGKey(1),
)

terms = dfx.MultiTerm(dfx.ODETerm(drift_ou), dfx.ControlTerm(diffusion_ou, brownian_motion))
args = {"mu": 1.0, "tau": 0.1, "sigma": 0.3}
sol_ou = dfx.diffeqsolve(
    terms,
    dfx.Heun(),
    t0,
    t1,
    dt0=DT,
    y0=jnp.ones((dim))*args["mu"],
    saveat=dfx.SaveAt(dense=True),
    max_steps=int(t1*(1./DT)),
    args=args
)

In this code, I am trying to create a 2-dimensional Ornstein-Uhlenbeck process. VirtualBrownianTree gives me dW in 2 dimensions; however, the Brownian motion is the same for both dimensions. Is there some way to make the "samples" independent?

Cheers!

patrick-kidger commented 1 year ago

Hey there. I'm not sure I understand. Taking your example and then running e.g. brownian_motion.evaluate(0.1), I get [0.14833272 0.35743967], i.e. different values in each dimension.

It should be the case that each dimension is treated independently.

jajcayn commented 1 year ago

Thanks for the answer! I can confirm that when I do the same thing as you, I get different values. However, when I tried to model the whole Ornstein-Uhlenbeck process, I got the same solution for both dimensions:

import diffrax as dfx
import jax
import numpy as np
import matplotlib.pyplot as plt
import jax.random as jr
import jax.numpy as jnp

# times in seconds
dt = 0.001
t0 = 0.
tmax = 10.

def drift_ou(t, y, args):
    """
    Ornstein-Uhlenbeck drift term.
    """
    return (args["mu"] - y) / args["tau"]

def diffusion_ou(t, y, args):
    """
    Ornstein-Uhlenbeck diffusion (noise) term.
    """
    return jnp.ones((2)) * args["sigma"]

# this sets up random Brownian motion / Wiener process
brownian_motion = dfx.VirtualBrownianTree(
    t0,
    tmax,
    tol=1e-5,
    shape=jax.ShapeDtypeStruct((2,), np.float64),
    key=jr.PRNGKey(2),
)

sde_terms_ou = dfx.MultiTerm(
    dfx.ODETerm(drift_ou),
    dfx.ControlTerm(diffusion_ou, brownian_motion),
)

args = {"mu": 1.0, "sigma": 0.5, "tau": 0.01}

sol = dfx.diffeqsolve(
    sde_terms_ou,
    dfx.Heun(),
    t0,
    tmax,
    dt0=dt,
    y0=jnp.ones((2)) * args["mu"],
    saveat=dfx.SaveAt(dense=True),
    max_steps=5 * int((tmax - t0) * (1.0 / dt)),
    args=args,
)

tvec = np.arange(t0, tmax, dt)
y = np.array([sol.evaluate(t) for t in tvec])

plt.plot(tvec, y, linewidth=0.3)
np.corrcoef(y[:, 0], y[:, 1])[0, 1]

I got a correlation of 1, and the plot looks like this: [image]

Since VirtualBrownianTree actually gives me different values, I am doing something wrong elsewhere, but I cannot find the bug.

Thanks!

patrick-kidger commented 1 year ago

This is because you're using ControlTerm, rather than WeaklyDiagonalControlTerm.

ControlTerm contracts together the diffusion and the Brownian motion. In this case the diffusion has shape (2,) and the BM has shape (2,) so their contraction has shape (). This scalar is then broadcast back up to shape (2,), i.e. the diffusion term is the same in both components.
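The effect can be reproduced without Diffrax. The following is a minimal sketch, assuming a hand-written Euler-Maruyama loop in plain JAX rather than Diffrax's Heun solver; `simulate_ou` and `combine` are hypothetical names introduced only for this illustration. It compares contracting the diffusion with the Brownian increment (what ControlTerm does for these shapes) against multiplying them pointwise.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def simulate_ou(combine, key, n_steps=20_000, dt=1e-3, mu=1.0, tau=0.1, sigma=0.3):
    """Euler-Maruyama for a 2-D OU process; combine(g, dW) builds the noise term."""
    g = jnp.ones(2) * sigma
    # Independent Brownian increments for every step and every component.
    dWs = jnp.sqrt(dt) * jr.normal(key, (n_steps, 2))

    def step(y, dW):
        y_next = y + (mu - y) / tau * dt + combine(g, dW)
        return y_next, y_next

    _, ys = jax.lax.scan(step, jnp.ones(2) * mu, dWs)
    return ys


key = jr.PRNGKey(0)

# Contraction: g and dW collapse to a scalar, which broadcasts over both components.
ys_contracted = simulate_ou(lambda g, dW: jnp.tensordot(g, dW, axes=dW.ndim), key)

# Pointwise product: each component keeps its own noise increment.
ys_pointwise = simulate_ou(lambda g, dW: g * dW, key)

corr_contracted = jnp.corrcoef(ys_contracted[:, 0], ys_contracted[:, 1])[0, 1]
corr_pointwise = jnp.corrcoef(ys_pointwise[:, 0], ys_pointwise[:, 1])[0, 1]
print("contracted:", corr_contracted)
print("pointwise: ", corr_pointwise)
```

With the contraction both components receive identical updates, so their trajectories coincide; with the pointwise product the sample correlation should be close to zero.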

To be precise: given a diffusion g of shape (a_1, ..., a_n, b_1, ..., b_m) and a BM W of shape (b_1, ..., b_m), then the resulting g dW term has shape (a_1, ..., a_n); this is computed via jnp.tensordot(g, dW, axes=dW.ndim).

In contrast WeaklyDiagonalControlTerm takes a diffusion g of shape (a_1, ..., a_n) and a BM W of shape (a_1, ..., a_n) and then just multiplies them together pointwise: g dW has shape (a_1, ..., a_n), and this is computed via g * dW.
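The two shape rules can be checked directly in JAX. This is a small sketch using only `jax.numpy`; the arrays are made-up values chosen for illustration, and no Diffrax term is constructed.

```python
import jax.numpy as jnp

dW = jnp.array([0.1, -0.2])  # Brownian increment, shape (2,)

# Contraction rule: all axes of dW are summed against the trailing axes of g.
g_vec = jnp.array([0.3, 0.3])  # diffusion of shape (2,)
contracted_vec = jnp.tensordot(g_vec, dW, axes=dW.ndim)
print(contracted_vec.shape)  # ()

g_mat = jnp.ones((3, 2))  # diffusion of shape (3, 2)
contracted_mat = jnp.tensordot(g_mat, dW, axes=dW.ndim)
print(contracted_mat.shape)  # (3,)

# Pointwise rule: g and dW share a shape and are multiplied elementwise.
pointwise = g_vec * dW
print(pointwise.shape)  # (2,)
```

The shape `()` in the first case is the scalar that later gets broadcast across both components of y.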


If all of that is a bit hard to follow in generality, then consider the canonical case in which g is a matrix and dW is a vector, and g dW is a matrix-vector product. Then WeaklyDiagonalControlTerm handles the case that g is a diagonal matrix, whilst ControlTerm handles the case that g is an arbitrary matrix.
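In matrix terms, the sketch below (plain `jax.numpy`, illustrative values only) checks that the pointwise product equals a diagonal matrix-vector product, and that the contract-then-broadcast behaviour from the original post equals a matrix-vector product with a matrix whose rows are all equal to g.

```python
import jax.numpy as jnp

g = jnp.array([0.5, 2.0])
dW = jnp.array([0.1, -0.2])

# Diagonal case: diag(g) @ dW is the same as the pointwise product g * dW.
diag_product = jnp.diag(g) @ dW
pointwise = g * dW
print(jnp.allclose(diag_product, pointwise))

# Contract-then-broadcast: the scalar g . dW is copied into every component,
# which matches a dense matrix with identical rows acting on dW.
broadcast = jnp.broadcast_to(jnp.tensordot(g, dW, axes=dW.ndim), (2,))
same_rows = jnp.stack([g, g]) @ dW
print(jnp.allclose(broadcast, same_rows))
```

The second case mixes both noise channels into every component, which is why the two solution components end up identical.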


For what it's worth I can see that this is an easy mistake to make. If you have some thoughts on additional checks we could add (that wouldn't compromise expressiveness if the user really does want to do this kind of thing) then I'd welcome any thoughts / any PRs. (Maybe we should prohibit broadcasting in certain scenarios?)

AHsu98 commented 4 months ago

I ran into this same confusion for a little bit. How would you feel about adding something like a ScaledControlTerm, where the vector field is always a multiple of the identity matrix? Alternatively, is there some way to wrap that into the regular control term, where it broadcasts a scalar? (I'm unsure about the safety/feasibility of that.) This seems like a reasonably common use case, for example for SDE sampling from a score-based diffusion model.

If you think it's worthwhile, I could possibly take a crack at it sometime over the next couple weeks when I have a bit more time.

patrick-kidger commented 4 months ago

I can see the value in something like that.

For what it's worth I am a little concerned by the proliferation of different kinds of control terms! I can see this getting out of hand quickly.

I'm wondering if what we should do is extend ControlTerm to accept any kind of lineax.AbstractLinearOperator, that maps between the Brownian-motion-space and the y-space. This would give a user complete control over any kind of thing they might want, without needing to increase the complexity in Diffrax.

Then you could do something like

def diffusion(t, y, args):
    return lineax.IdentityLinearOperator()

diffrax.ControlTerm(diffusion, brownian_motion)

AHsu98 commented 4 months ago

I see, that makes sense. There's currently no way to specify the diffusion in a matrix-free way, right (other than WeaklyDiagonalControlTerm)?

patrick-kidger commented 4 months ago

There is: subclass AbstractTerm. There's nothing special about the built-in ODETerm, ControlTerm etc. -- as an end-user you can create subclasses as well!

In particular, if you're trying to avoid materialising the matrix prior to doing the diffusion-Brownian contraction, then you can implement the vf_prod method to do this. Also take a look at the terms page in the docs.

E.g. for reference: in what I'm suggesting above, given an operator returned from the vector field, then we would tweak ControlTerm.vf to return operator.as_matrix(), and ControlTerm.vf_prod to perform operator.mv.
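The split between the two methods can be sketched in plain JAX. Everything here is an assumption made for illustration: `ScaledIdentity` is a hypothetical stand-in for a linear operator (it is not a lineax class), and `vf` / `vf_prod` are free functions that only mimic the method names mentioned above; this is not a Diffrax subclass.

```python
import jax.numpy as jnp


class ScaledIdentity:
    """Hypothetical operator representing sigma * I, for illustration only."""

    def __init__(self, sigma, size):
        self.sigma = sigma
        self.size = size

    def as_matrix(self):
        # Materialises the full (size, size) matrix.
        return self.sigma * jnp.eye(self.size)

    def mv(self, vector):
        # Matrix-free product: never builds the matrix.
        return self.sigma * vector


def diffusion(t, y, args):
    return ScaledIdentity(args["sigma"], y.shape[0])


def vf(t, y, args):
    # Dense vector field, as a solver that needs the matrix itself would see it.
    return diffusion(t, y, args).as_matrix()


def vf_prod(t, y, args, control):
    # Fused vector-field/control product that skips the dense matrix.
    return diffusion(t, y, args).mv(control)


y = jnp.ones(2)
dW = jnp.array([0.1, -0.2])
args = {"sigma": 0.3}

dense_result = vf(0.0, y, args) @ dW
matrix_free_result = vf_prod(0.0, y, args, dW)
print(jnp.allclose(dense_result, matrix_free_result))
```

Both routes give the same g dW term; the matrix-free route just does less work, which matters when y is large.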

Let me know if that makes sense or not, this is a fairly advanced API surface.

AHsu98 commented 4 months ago

That makes sense, thank you. This structure of vector field * control is really elegant, I've enjoyed working with diffrax and equinox a lot!