luweizheng closed this issue 2 years ago.
I think your question is missing some details. I just took a look at the paper, and it seems there are two more dependencies.
The paper says that $u$ is modeled by a neural network, so we can define one as:
```python
import equinox as eqx

class U(eqx.Module):
    # your implementation here: a network mapping (t, x) to a scalar
    ...

u = U()
```
The function $Z(t, X_t)$ is the gradient of $u$ with respect to $x$, evaluated at $X_t$. I guess it would look like this (not entirely sure, though):
```python
import jax

def Z(t, X):
    # gradient of u(t, x) with respect to x, evaluated at X
    return jax.grad(lambda x: u(t, x))(X)
```
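A quick way to sanity-check the `jax.grad` construction is to swap in a hypothetical closed-form `u` whose gradient is known. Here $u(t, x) = e^{-t}\lVert x\rVert^2$, so $\nabla_x u = 2e^{-t}x$. The function `u_closed_form` is illustrative only and stands in for the network.

```python
import jax
import jax.numpy as jnp


def u_closed_form(t, x):
    # Hypothetical stand-in for the network: u(t, x) = exp(-t) * ||x||^2
    return jnp.exp(-t) * jnp.sum(x**2)


def Z(t, X):
    # Gradient of u with respect to its second argument, evaluated at X
    return jax.grad(lambda x: u_closed_form(t, x))(X)


X = jnp.array([1.0, -2.0, 0.5])
z = Z(0.0, X)
expected = 2.0 * X  # at t = 0, exp(-t) = 1
```

Equivalently, `jax.grad(u, argnums=1)(t, X)` differentiates with respect to the second argument without the lambda.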
The paper solves a system of stochastic differential equations in two variables, $X_t$ and $Y_t$. The argument `y` of a vector field can be a PyTree (tuples, lists, dictionaries, and any nesting of them). In this case, you can make `y` a tuple $(X_t, Y_t)$.
For example, the drift term would look like this:
```python
def vector_field(t, y, args):
    X = y[0]  # X_t
    Y = y[1]  # Y_t
    Z_out = Z(t, X)
    mu_out = mu(t, X, Y, Z_out)  # you need to define the functions mu and phi
    phi_out = phi(t, X, Y, Z_out)
    # return a tuple as well: the first entry corresponds to X, the second to Y
    return (mu_out, phi_out)
```
You can do the same thing for the diffusion term.
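Because the state is a tuple, the drift returns a tuple with the same structure, and anything that updates the state can operate leaf by leaf. The sketch below defines hypothetical `mu`, `phi` and `Z` (not the paper's) and takes one explicit drift step with `jax.tree_util.tree_map`. It only demonstrates that the PyTree shapes line up: `X` has shape `(N,)` and `Y` is a scalar.

```python
import jax
import jax.numpy as jnp


def mu(t, X, Y, Z):
    # Hypothetical drift of X (the paper's mu may differ)
    return -0.5 * X


def phi(t, X, Y, Z):
    # Hypothetical drift of Y; it must return a scalar to match Y
    return 0.1 * Y + jnp.sum(Z)


def Z(t, X):
    # Hypothetical gradient of u(t, x) = sum(x**2)
    return 2.0 * X


def vector_field(t, y, args):
    X, Y = y
    Z_out = Z(t, X)
    return (mu(t, X, Y, Z_out), phi(t, X, Y, Z_out))


state = (jnp.ones(4), jnp.array(1.0))
dt = 0.01
drift = vector_field(0.0, state, None)
# Update each leaf of the tuple with its matching drift component
next_state = jax.tree_util.tree_map(lambda s, f: s + f * dt, state, drift)
```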
I haven't read the paper in detail, but the two SDEs probably share the same Brownian path, so be careful here: you may need to adjust how the Brownian path is applied to each component.
Note that the initial condition is $Y_0 = u(0, X_0)$, so when you set `y0` of `diffrax.diffeqsolve`, make sure to pass the tuple $(X_0, u(0, X_0))$.
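For instance, with a hypothetical closed-form `u` standing in for the trained network, the initial state could be assembled as follows. `u` returns a 0-d array, so $Y_0$ is a scalar. That shape must match what the drift and diffusion produce for $Y$.

```python
import jax.numpy as jnp


def u(t, x):
    # Hypothetical stand-in for the network u(t, x)
    return jnp.exp(-t) * jnp.sum(x**2)


N = 3
X0 = jnp.zeros(N).at[0].set(1.0)  # hypothetical initial condition
y0 = (X0, u(0.0, X0))  # (X_0, Y_0) with Y_0 = u(0, X_0)
```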
Hope this helps :).
@anh-tong Previously I thought `y` had to be a scalar. Your reply really helps! It reminds me that `y` can be a tuple of different states.
@anh-tong Hi Anh,
Your demo code almost solves my problem. The only thing I haven't solved yet is the control term of $Y$. In my case, $X_t$ is an N-dimensional array and $Y_t$ is a scalar (or a 1-dimensional array). $Y_t$ is given by
$$ Y_t = u(t, X_t) $$
The vector field you provided works. But for the control term, both the $X_t$ state and the $Y_t$ state depend on the same Brownian path (an N-dimensional array). Here is how I handle the dimensions of $X_t$ and $Y_t$ without diffrax, using the Euler method:
```python
x1 = x0 + mu_fn(curr_t, x0, y0, z0) * dt + \
    sigma_fn(curr_t, x0, y0) * dW
y1_tilde = y0 + phi_fn(curr_t, x0, y0, z0) * dt + \
    jnp.sum(z0 * sigma_fn(curr_t, x0, y0) * dW, keepdims=True)
```
Here `z0` is the `Z_out` in your demo code, and the control term is `jnp.sum(z0 * sigma_fn * dW)`. `z0`, the output of `sigma_fn`, and `dW` are all N-dimensional arrays. `jnp.sum` reduces the N entries to a single one, so `y1_tilde` is a 1-dimensional array of size 1.
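As a complete loop, the scheme above might look like the plain-JAX Euler-Maruyama sketch below. It uses `jax.lax.scan` so that it can be jit-compiled. The key point is that one Brownian increment `dW` of shape `(N,)` is drawn per step and fed to both updates. The coefficient functions `mu_fn`, `sigma_fn` and `phi_fn`, as well as the closed-form `u`, are hypothetical stand-ins. Unlike your snippet, $Y$ is kept as a scalar here, and `Z` is computed from the current `u` via `jax.grad`. The `_tilde` suffix in your code suggests that $\tilde{Y}$ is later compared with $u(t + dt, X_{t+dt})$; that comparison is not shown.

```python
import jax
import jax.numpy as jnp


def u(t, x):
    # Hypothetical stand-in for the network u(t, x)
    return jnp.exp(-t) * jnp.sum(x**2)


def mu_fn(t, x, y, z):
    return jnp.zeros_like(x)  # hypothetical drift of X


def sigma_fn(t, x, y):
    return 0.2 * x  # hypothetical diagonal diffusion of X, shape (N,)


def phi_fn(t, x, y, z):
    return 0.05 * y  # hypothetical drift of Y, scalar


def euler_fbsde(x0, t0, t1, num_steps, key):
    dt = (t1 - t0) / num_steps
    t_init = jnp.asarray(t0, dtype=jnp.float32)
    y_init = u(t_init, x0)

    def step(carry, k):
        t, x, y = carry
        z = jax.grad(u, argnums=1)(t, x)  # Z_t = grad_x u(t, X_t)
        dW = jnp.sqrt(dt) * jax.random.normal(k, x.shape)  # shared increment
        sig = sigma_fn(t, x, y)
        x_next = x + mu_fn(t, x, y, z) * dt + sig * dW
        y_next = y + phi_fn(t, x, y, z) * dt + jnp.sum(z * sig * dW)
        return (t + dt, x_next, y_next), (x_next, y_next)

    keys = jax.random.split(key, num_steps)
    _, (xs, ys) = jax.lax.scan(step, (t_init, x0, y_init), keys)
    return xs, ys


xs, ys = euler_fbsde(jnp.ones(3), 0.0, 1.0, 50, jax.random.PRNGKey(0))
```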
So my question is: how do I add `jnp.sum` to the control term? `ControlTerm` seems to be a PyTree, not a callable, so `jnp.sum(ControlTerm)` does not work.
In this case, you need to implement a new `ControlTerm`.
Let us write a diffusion function as follows (not exactly the same as your code block):
```python
def sigma_fn(t, x0, y0):
    # your implementation
    # return a vector of size N
    ...

def diffusion(t, y, args):
    # y is a tuple
    x0 = y[0]  # size N
    y0 = y[1]  # scalar
    z_0 = Z(t, x0)
    sig = sigma_fn(t, x0, y0)
    ret1 = sig
    ret2 = z_0 * sig  # a vector of size N
    # return a tuple (size N, size N)
    return (ret1, ret2)
```
You may need to implement a custom `ControlTerm`, in particular its `prod` method, which describes how the vector field of the diffusion term is multiplied by the increment of the Brownian path.
```python
import diffrax
import jax.numpy as jnp
from jaxtyping import PyTree

class CustomControlTerm(diffrax.ControlTerm):
    @staticmethod
    def prod(vf: PyTree, control: PyTree) -> PyTree:
        # `vf` is the output of the diffusion function above
        first = vf[0]  # the first element (size N)
        second = vf[1]  # the second element (size N)
        # `control` is the increment dW of the N-dimensional Brownian path (size N)
        prod1 = first * control  # size N
        prod2 = jnp.sum(second * control)  # scalar
        # return a PyTree: a tuple (size N, scalar)
        return (prod1, prod2)
```
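To check that this `prod` reproduces the hand-written Euler increments, the same logic can be tested as a plain function outside diffrax. `custom_prod` below copies the body of `CustomControlTerm.prod`, and the arrays are arbitrary test values. The $X$ increment is `sigma * dW`, and the $Y$ increment is `sum(z0 * sigma * dW)`, which is the same reduction as `jnp.sum` in your Euler code.

```python
import jax.numpy as jnp


def custom_prod(vf, control):
    # Same logic as CustomControlTerm.prod: vf = (sigma, z * sigma), control = dW
    first, second = vf
    return (first * control, jnp.sum(second * control))


sigma = jnp.array([0.1, 0.2, 0.3])
z0 = jnp.array([1.0, -1.0, 2.0])
dW = jnp.array([0.5, 0.25, -0.5])
dX, dY = custom_prod((sigma, z0 * sigma), dW)
```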
Finally, the new `ControlTerm` is constructed like this:
```python
control = diffrax.UnsafeBrownianPath(shape=(N,), key=...)
diffusion_term = CustomControlTerm(diffusion, control)
```
That should be it. Let me know if you have more questions.
Hi Patrick,
I'd like to know how to use diffrax to solve forward-backward SDEs like the ones proposed in this paper:
In these equations, there are three evolving states in the system: X, Y and Z. Currently, the docs say that the vector field, which is a callable, takes three arguments `(t, y, args)`: `t` is a scalar representing the integration time, `y` is the evolving state of the system, and `args` are any static arguments.