aspannaus opened this issue 2 months ago
The classic (https://github.com/patrick-kidger/diffrax/issues/446#issuecomment-2187405940) strikes once again 😉
It seems like there are a few errors here. First, you return a diagonal, but `ControlTerm` expects a full matrix by default, so you need to wrap it (with a `lineax.DiagonalLinearOperator`). Second, SEA requires a `SpaceTimeLevyArea` (this should go in the solver docs, IMO). Finally, SEA requires additive noise (i.e. `g` is not a function of `x`), so you can't use this solver with that noise term.
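To see why the diagonal matters: by default `ControlTerm` treats the diffusion's output as a matrix and applies it to the Brownian increment as a matrix-vector product. The pure-JAX sketch below uses made-up values and does not call diffrax's internals. It compares treating the diagonal as a full matrix, multiplying elementwise (what a diagonal operator computes), and contracting the bare vector against `dW`.

```python
import jax.numpy as jnp

# Hypothetical diagonal entries of the diffusion, and a Brownian increment dW.
diag = jnp.array([0.0, 0.0, 0.0, 0.2, 0.3])
dw = jnp.array([0.1, -0.2, 0.3, 0.5, -1.0])

# Treating the diagonal as a full (5, 5) matrix gives one noise term per state.
full = jnp.diag(diag) @ dw
# Elementwise product: same result, without materialising the matrix.
elementwise = diag * dw
# Contracting the bare vector against dW collapses everything to a scalar.
collapsed = diag @ dw

print(full.shape, collapsed.shape)  # (5,) ()
print(jnp.allclose(full, elementwise))  # True
```

Wrapping the vector in `lineax.DiagonalLinearOperator`, or returning `jnp.diag(...)`, keeps one noise term per state instead of a single scalar.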
Using all three fixes, you get something that works. Note the switch to `GeneralShARK`, which allows non-additive noise:
```python
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import diffrax
import lineax as lx


def sde_drift(t, y, args):
    N, _ = args
    # beta and gamma are stored in log space so they stay positive.
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # log(beta) and log(gamma) have no drift, only diffusion.
    dbeta = 0.0
    dgamma = 0.0
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])
    return dy


def sde_diffusion(t, y, args):
    _, sigma_1 = args
    y1, y2, y3, y4, y5 = y
    # Noise acts only on log(beta) and log(gamma). Wrapping the diagonal in a
    # lineax operator makes ControlTerm apply it elementwise rather than
    # treating it as a full matrix.
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    return lx.DiagonalLinearOperator(diagonal)


def sde():
    t0 = 0.0
    t1 = 100.0
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)  # (N, sigma_1)
    # GeneralShARK needs space-time Levy area from the Brownian motion.
    bm = diffrax.VirtualBrownianTree(
        t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42),
        levy_area=diffrax.SpaceTimeLevyArea,
    )
    terms = diffrax.MultiTerm(
        diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm)
    )
    # GeneralShARK handles non-additive (state-dependent) noise, unlike SEA.
    solver = diffrax.GeneralShARK()
    saveat = diffrax.SaveAt(dense=True)
    print(type(terms))
    sol = diffrax.diffeqsolve(
        terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat
    )
    print(sol)


sde()
```
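The third point is why the working version uses `GeneralShARK` rather than `SEA`: SEA assumes additive noise, meaning the diffusion does not depend on `y`. One way to check a given diffusion is to evaluate its Jacobian with respect to `y`. A nonzero Jacobian at any test point shows the noise is not additive. The sketch below assumes the diffusion returns a plain `jnp` array, here the diagonal vector from `sde_diffusion` before it is wrapped in a lineax operator. `additive_diffusion` and `depends_on_state` are hypothetical helpers for illustration.

```python
import jax
import jax.numpy as jnp


def diffusion_diagonal(t, y, args):
    # Same diagonal as sde_diffusion in the thread, returned as a plain vector.
    _, sigma_1 = args
    return jnp.array([0.0, 0.0, 0.0, sigma_1 * y[3], sigma_1 * y[4]])


def additive_diffusion(t, y, args):
    # Hypothetical additive-noise variant: constant in y.
    _, sigma_1 = args
    return jnp.array([0.0, 0.0, 0.0, sigma_1, sigma_1])


def depends_on_state(diffusion, t, y, args):
    """True if the diffusion's Jacobian w.r.t. y is nonzero at this point."""
    jac = jax.jacfwd(diffusion, argnums=1)(t, y, args)
    return bool(jnp.any(jac != 0.0))


y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
args = (4000.0, 0.2)
print(depends_on_state(diffusion_diagonal, 0.0, y0, args))  # True: not additive
print(depends_on_state(additive_diffusion, 0.0, y0, args))  # False
```

A zero Jacobian at a single point is only consistent with additive noise; it does not prove it everywhere.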
Thanks for the reply; I must have missed some of the points about the solver in the docs.
Trying the code you suggested, I get the error `ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>, value: Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=2/0)>`.
I had tried this previously without success, but perhaps it is correct and something is happening behind the scenes?
For completeness, here are the library versions I'm using:
Thanks again for the assistance.
Yes, I was using diffrax 0.6.0
That was it, thanks again!
Hi, I had the same issue as above: `ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>`
I updated all the packages to the versions above, and now I get the error:
`AttributeError: module 'opt_einsum' has no attribute 'paths'`
Do you have any ideas?
What versions are you using?
Hi Owen, I fixed the issue by setting up a new conda environment and making sure jax, jaxlib, equinox, lineax and diffrax were installed through pip rather than conda (where diffrax 0.6.0 is not yet available). I don't really know what the underlying issue was, but thanks anyway.
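Since the fix came down to which copies of the packages were actually installed, here is a small stdlib-only sketch for printing the versions that the active environment reports. It reads installed package metadata via `importlib.metadata`. The package list is simply the ones mentioned in this thread, and `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata

# Packages mentioned in this thread; adjust the list as needed.
PACKAGES = ["jax", "jaxlib", "equinox", "lineax", "diffrax", "opt_einsum"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the environment that raises the error shows exactly which versions Python picks up, which is what matters when pip and conda installs are mixed.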
Hi all,
thanks for the great library. I'm having an issue implementing a coupled system of SDEs. I'm getting the error ``ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`)``.
The system is:
The code is:
Printing the type of `terms` yields `diffrax._term.MultiTerm`, so I'm not entirely sure where to look. What would you suggest checking? Thanks in advance.
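Without the system and code it is hard to say what triggers this. One cheap first check, offered as an assumption about a possible cause rather than a diagnosis, is the shapes: the drift should return the same shape as `y0`, and a `ControlTerm` diffusion should return a `(state_dim, noise_dim)` matrix matching the Brownian motion's `shape`, since the two are multiplied. `jax.eval_shape` checks this without running the functions. `drift`, `diffusion`, `check_shapes` and the sizes below are hypothetical placeholders, and the sketch assumes the state is a single array.

```python
import jax
import jax.numpy as jnp

STATE_DIM, NOISE_DIM = 3, 2  # hypothetical sizes for a coupled system


def drift(t, y, args):
    # Placeholder drift: must return the same shape as y.
    return -y


def diffusion(t, y, args):
    # Placeholder diffusion: (state_dim, noise_dim) matrix for a ControlTerm.
    return 0.1 * jnp.ones((STATE_DIM, NOISE_DIM)) * y[:, None]


def check_shapes(drift, diffusion, y0, noise_dim, args=None):
    """Return the drift/diffusion output shapes and whether both are as expected."""
    t0 = 0.0
    f = jax.eval_shape(drift, t0, y0, args)
    g = jax.eval_shape(diffusion, t0, y0, args)
    ok_drift = f.shape == y0.shape
    ok_diffusion = g.shape == y0.shape + (noise_dim,)
    return f.shape, g.shape, ok_drift and ok_diffusion


y0 = jnp.ones(STATE_DIM)
print(check_shapes(drift, diffusion, y0, NOISE_DIM))  # ((3,), (3, 2), True)
```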