patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Coupled SDE System Implementation #461

Open aspannaus opened 2 months ago

aspannaus commented 2 months ago

Hi all,

Thanks for the great library. I'm having an issue implementing a coupled system of SDEs. I'm getting the error ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`). The system is:

$$
\begin{aligned}
\mathrm{d}S(t) &= -\beta(t)\,S(t)\,\frac{I(t)}{N}\,\mathrm{d}t, \\
\mathrm{d}I(t) &= \Big(\beta(t)\,S(t)\,\frac{I(t)}{N} - \gamma(t)\,I(t)\Big)\,\mathrm{d}t, \\
\mathrm{d}R(t) &= \gamma(t)\,I(t)\,\mathrm{d}t, \\
\mathrm{d}\log\beta(t) &= w_3\,\mathrm{d}B_w(t), \\
\mathrm{d}\log\gamma(t) &= u_3\,\mathrm{d}B_u(t).
\end{aligned}
$$
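To make the structure of this system concrete, here is a minimal hand-rolled Euler-Maruyama sketch in plain JAX (no diffrax). It assumes constant noise scales w_3 = u_3 = 0.2 and reuses the initial condition and N = 4000 from the code below; the function names are hypothetical. Because the S, I and R drifts sum to zero and carry no noise, S + I + R should stay constant up to floating-point rounding.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def drift(y, N):
    # y = [S, I, R, log_beta, log_gamma]
    S, I, R, log_beta, log_gamma = y
    infection = jnp.exp(log_beta) * S * I / N
    recovery = jnp.exp(log_gamma) * I
    # The log-rates have no drift term
    return jnp.array([-infection, infection - recovery, recovery, 0.0, 0.0])


def diffusion(y, w3, u3):
    # Additive noise: only the log-rates are driven, with constant scales
    return jnp.array([0.0, 0.0, 0.0, w3, u3])


def euler_maruyama(y0, key, N, w3, u3, t1=100.0, dt=0.1):
    n_steps = int(round(t1 / dt))
    keys = jr.split(key, n_steps)

    def step(y, k):
        # Brownian increment over one step has variance dt
        dW = jnp.sqrt(dt) * jr.normal(k, y.shape)
        y_next = y + drift(y, N) * dt + diffusion(y, w3, u3) * dW
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, keys)
    return ys


y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
ys = euler_maruyama(y0, jr.PRNGKey(42), N=4000.0, w3=0.2, u3=0.2)
```

`ys` has one row per step; `jnp.exp(ys[:, 3])` recovers the simulated beta(t) path.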

The code is:

```python
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

import diffrax

def sde_drift(t, y, args):
    N, _ = args
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # only diffusion, no drift
    dbeta = 0.0  # jnp.array([0.0])
    dgamma = 0.0  # jnp.array([0.0])
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])

    return dy

def sde_diffusion(t, y, args):
    _, sigma_1 = args
    y1, y2, y3, y4, y5 = y
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    return diagonal 

def sde():

    t0 = 0
    t1 = 100
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)

    bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42))
    terms = diffrax.MultiTerm(diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm))
    solver = diffrax.SEA()
    saveat = diffrax.SaveAt(dense=True)

    print(type(terms))

    sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat)
    print(sol)
```

Printing the type of `terms` yields `diffrax._term.MultiTerm`, so I'm not entirely sure where to look. What would you suggest I check?

Thanks in advance.

lockwo commented 2 months ago

The classic (https://github.com/patrick-kidger/diffrax/issues/446#issuecomment-2187405940) strikes once again 😉

It seems like there are a few errors here. First, you return a diagonal, but `ControlTerm` expects a full matrix by default, so you need to fix that (with a `lineax.DiagonalLinearOperator`). Second, `SEA` requires space-time Lévy area (`diffrax.SpaceTimeLevyArea`) from the Brownian motion (this should go in the solver docs, imo). Finally, `SEA` requires additive noise (i.e. g is not a function of y), so you can't use this solver with that noise term.
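The first and third points can be checked in plain JAX without running a solver. This sketch rests on an assumption: that `ControlTerm`'s default product contracts the vector field's output against the Brownian increment with a tensor dot over the increment's axes, which is my reading of "full matrices by default". Under that reading, a length-5 diagonal contracted with a length-5 increment collapses to a scalar, while a dense `jnp.diag` matrix gives the intended element-wise product. Additive noise means g does not depend on y, so its Jacobian with respect to y is identically zero. `diffusion_diag` and `diffusion_additive` are hypothetical names, and `dW` is made-up data.

```python
import jax
import jax.numpy as jnp

sigma_1 = 0.2
y = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
dW = jnp.array([0.1, -0.2, 0.3, -0.4, 0.5])  # stand-in Brownian increment


def diffusion_diag(y):
    # Diagonal of g(y) as in the issue: scales with the log-rates themselves
    return jnp.array([0.0, 0.0, 0.0, sigma_1 * y[3], sigma_1 * y[4]])


def diffusion_additive(y):
    # Constant diagonal, as in the equations (w_3 = u_3 = sigma_1 assumed)
    return jnp.array([0.0, 0.0, 0.0, sigma_1, sigma_1])


# Point 1: contracting a vector with dW over one axis gives a scalar...
vec_contraction = jnp.tensordot(diffusion_diag(y), dW, axes=1)
# ...whereas a dense (5, 5) diagonal matrix gives a 5-vector update
mat_contraction = jnp.tensordot(jnp.diag(diffusion_diag(y)), dW, axes=1)

# Point 3: additive noise <=> dg/dy is identically zero
jac_issue = jax.jacfwd(diffusion_diag)(y)  # nonzero entries at [3, 3] and [4, 4]
jac_additive = jax.jacfwd(diffusion_additive)(y)  # all zeros
```

Returning `jnp.diag(...)` from the diffusion may be an alternative to the lineax operator, at the cost of materialising a (5, 5) matrix.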

Applying all three fixes (including swapping `SEA` for `GeneralShARK`, which does not assume additive noise), you get something that works and looks like:


```python
import jax.numpy as jnp
import jax.random as jr
import lineax as lx

import diffrax


def sde_drift(t, y, args):
    N, _ = args
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # log-rates have only diffusion, no drift
    dbeta = 0.0
    dgamma = 0.0
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])

    return dy


def sde_diffusion(t, y, args):
    _, sigma_1 = args
    _, _, _, y4, y5 = y
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    # ControlTerm treats a plain array as a full matrix; the lineax operator
    # makes the diagonal act element-wise (needed diffrax 0.6.0 in this thread)
    return lx.DiagonalLinearOperator(diagonal)


def sde():
    t0 = 0
    t1 = 100
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)

    # SEA and GeneralShARK both need space-time Levy area from the Brownian motion
    bm = diffrax.VirtualBrownianTree(
        t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42), levy_area=diffrax.SpaceTimeLevyArea
    )
    terms = diffrax.MultiTerm(diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm))
    # GeneralShARK handles noise that depends on y; SEA assumes additive noise
    solver = diffrax.GeneralShARK()
    saveat = diffrax.SaveAt(dense=True)

    sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat)
    print(sol)


sde()
```
aspannaus commented 2 months ago

Thanks for the reply; I must have missed some of the points you mention about the solvers in the docs.

Trying the code you suggested, I get the error `ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>, value: Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=2/0)>`. I had tried this previously without success; perhaps it is correct and something is happening behind the scenes?

For completeness, here are the library versions I'm using:

Thanks again for the assistance.

lockwo commented 2 months ago

Yes, I was using diffrax 0.6.0
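Since the fix turned out to be a version mismatch, a quick way to see which versions are actually active in an environment is to query the installed package metadata. A minimal stdlib sketch; `installed_versions`, `version_tuple` and `supports_lineax_diagonal` are hypothetical helpers, and the 0.6.0 threshold is taken only from this thread, not from the diffrax changelog.

```python
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ("jax", "jaxlib", "equinox", "lineax", "diffrax")


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def version_tuple(v):
    # "0.6.2.post1" -> (0, 6, 2); keeps only leading digits of each component
    nums = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        nums.append(int(digits) if digits else 0)
    return tuple(nums)


def supports_lineax_diagonal(diffrax_version):
    # Assumption from this thread: the lineax-operator fix worked on diffrax 0.6.0
    return diffrax_version is not None and version_tuple(diffrax_version) >= (0, 6, 0)


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```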

aspannaus commented 2 months ago

That was it, thanks again!

SoerenNagel commented 3 weeks ago

Hi, I had the same issue as above:

`ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>`

I updated all the packages to the versions above, and now I get the error:

`AttributeError: module 'opt_einsum' has no attribute 'paths'`

Do you have any ideas?

lockwo commented 3 weeks ago

What versions are you using?

SoerenNagel commented 3 weeks ago

Hi Owen, I fixed the issue by setting up a new conda environment and making sure jax, jaxlib, equinox, lineax and diffrax were installed through pip rather than conda (where diffrax 0.6.0 is not yet available). I don't really know what the underlying issue was, but thanks anyway.