patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Can diffrax solve forward-backward SDE? #166

Closed: luweizheng closed this 2 years ago

luweizheng commented 2 years ago

Hi Patrick,

I'd like to know how to use diffrax to solve a forward-backward SDE like the one proposed in this paper:

[Image: the forward-backward SDE system from the paper]

In these equations, there are three evolving states in the system: X, Y and Z. However, the docs say that the vector field is a callable taking three arguments `(t, y, args)`: `t` is a scalar representing the integration time, `y` is the evolving state of the system, and `args` are any static arguments. With three states, I am not sure how to fit them into this signature.

anh-tong commented 2 years ago

I think your question is missing some parts. I just took a look at the paper, and it seems there are two more dependencies:

  1. $Y_t = u(t, X_t)$
  2. $Z_t = Du(t, X_t)$

The paper says that $u$ is modeled by a neural network, so we can define one as follows:

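```python
import equinox as eqx

class U(eqx.Module):
    # your implementation here, e.g. an MLP taking (t, x)

    def __call__(self, t, x):
        ...  # should return a scalar so that jax.grad can be applied to it

u = U()
```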

The function $Z(t, X_t)$ is the derivative of $u(t, X_t)$ with respect to $x$, so I guess it will look something like this (not 100% sure, though):

```python
import jax

def Z(t, X):
    # gradient of u with respect to x, evaluated at X
    return jax.grad(lambda x: u(t, x))(X)
```
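To sanity-check the `jax.grad` approach without training anything, here is a minimal sketch in which a hypothetical closed-form `u_toy` stands in for the neural network `U`. Its gradient with respect to `x` is known analytically, so the two can be compared directly.

```python
import jax
import jax.numpy as jnp

def u_toy(t, x):
    # Hypothetical stand-in for the network u(t, x); returns a scalar.
    return jnp.exp(-t) * jnp.sum(x ** 2)

def Z_toy(t, x):
    # Z(t, x) = D_x u(t, x); jax.grad requires a scalar-valued function.
    return jax.grad(lambda x_: u_toy(t, x_))(x)

x = jnp.array([1.0, -2.0, 0.5])
t = 0.3
# The analytic gradient of exp(-t) * |x|^2 is 2 * exp(-t) * x.
print(jnp.allclose(Z_toy(t, x), 2.0 * jnp.exp(-t) * x))  # True
```

If the real network returns shape `(1,)` instead of a scalar, it needs to be squeezed or indexed before `jax.grad` will accept it.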

The paper solves a system of differential equations in two variables, $X_t$ and $Y_t$. The argument `y` of the vector field can be a PyTree (a tuple, list, dictionary, or any nesting of them), so in this case you can make `y` a tuple of $X_t$ and $Y_t$.

For example, the drift term will look like this:

```python
def vector_field(t, y, args):
    X = y[0]   # X_t
    Y = y[1]   # Y_t
    Z_out = Z(t, X)
    mu_out = mu(t, X, Y, Z_out)    # you need to define the functions mu and phi
    phi_out = phi(t, X, Y, Z_out)
    # return a tuple as well: the first element corresponds to X, the second to Y
    return (mu_out, phi_out)
```
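Because `y` is just a PyTree, anything that maps over PyTrees works on the tuple state. The sketch below uses hypothetical toy choices of `u`, `mu` and `phi` (not the paper's), stores $Y$ with shape `(1,)` and sets the initial condition to $(X_0, u(0, X_0))$, then takes one drift-only Euler step with `jax.tree_util.tree_map`.

```python
import jax
import jax.numpy as jnp

def u(t, x):
    # Hypothetical scalar-valued stand-in for the network u(t, x).
    return jnp.exp(-t) * jnp.sum(x ** 2)

def Z(t, X):
    return jax.grad(lambda x: u(t, x))(X)

def mu(t, X, Y, Z_out):
    return -X                 # hypothetical drift of X (size N)

def phi(t, X, Y, Z_out):
    return 0.05 * Y           # hypothetical drift of Y (size 1)

def vector_field(t, y, args):
    X, Y = y
    Z_out = Z(t, X)
    return (mu(t, X, Y, Z_out), phi(t, X, Y, Z_out))

N = 4
X0 = jnp.ones(N)
y = (X0, u(0.0, X0)[None])    # initial state (X_0, u(0, X_0)), shapes (N,), (1,)
dt = 0.01
drift = vector_field(0.0, y, None)
# One drift-only Euler step, applied leaf by leaf over the tuple.
y_next = jax.tree_util.tree_map(lambda yi, fi: yi + fi * dt, y, drift)
print(y_next[0].shape, y_next[1].shape)  # (4,) (1,)
```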

You can do the same thing for the diffusion term.

I haven't read the paper in detail, but the two SDEs probably share the same Brownian path, so be careful here: you may need to tweak how the Brownian path is used.

Note that the initial condition is $Y_0 = u(0, X_0)$, so when you set `y0` in `diffrax.diffeqsolve`, make sure to pass the tuple $(X_0, u(0, X_0))$.

Hope this helps :)

luweizheng commented 2 years ago

@anh-tong I previously thought `y` had to be a scalar. Your reply really helps! It reminded me that `y` can be a tuple of different states.

luweizheng commented 2 years ago

@anh-tong Hi Anh,

Your demo code almost solves my problem. The only thing I haven't solved yet is the control term for $Y$. In my case, $X_t$ is an N-dimensional tensor and $Y_t$ is a scalar (or a 1-dimensional array). $Y_t$ is the result of

$$ Y_t = u(t, X_t) $$

The vector field you provided works. For the control term, however, both the $X_t$ and $Y_t$ states depend on the same Brownian path (an N-dimensional array). Here is how I handle the dimensions of $X_t$ and $Y_t$ without diffrax, using the Euler method:

```python
x1 = x0 + mu_fn(curr_t, x0, y0, z0) * dt + \
    sigma_fn(curr_t, x0, y0) * dW
y1_tilde = y0 + phi_fn(curr_t, x0, y0, z0) * dt + \
    jnp.sum(z0 * sigma_fn(curr_t, x0, y0) * dW, keepdims=True)
```

Here `z0` is `Z_out` from your demo code, and the control term is `jnp.sum(z0 * sigma_fn(...) * dW)`. `z0`, the output of `sigma_fn`, and `dW` are all N-dimensional arrays; `jnp.sum` reduces the N dimensions to one, so `y1_tilde` is a 1-element array.
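Written out as equations, the Euler step above is the following scheme. It is read off the code rather than copied from the paper: $\sigma$ acts elementwise on $\Delta W_n$ (as `sigma_fn(...) * dW` does), and $Z_n = Du(t_n, X_n)$.

$$
\begin{aligned}
X_{n+1} &= X_n + \mu(t_n, X_n, Y_n, Z_n)\,\Delta t + \sigma(t_n, X_n, Y_n) \odot \Delta W_n,\\
\tilde{Y}_{n+1} &= Y_n + \varphi(t_n, X_n, Y_n, Z_n)\,\Delta t + \sum_{i=1}^{N} Z_{n,i}\,\sigma_i(t_n, X_n, Y_n)\,\Delta W_{n,i}.
\end{aligned}
$$

The sum over $i$ is the reduction that `jnp.sum` performs, and it is what makes the $Y$ control term differ from the elementwise $X$ control term, even though both use the same $\Delta W_n$.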

So my question is: how do I add `jnp.sum` to the control term? `ControlTerm` seems to be a PyTree, not a callable function, so `jnp.sum(ControlTerm)` does not work.

anh-tong commented 2 years ago

In this case, you need to implement a new ControlTerm.

Let us write a diffusion function as follows (not exactly the same as your code block):

```python
def sigma_fn(t, x0, y0):
    ...  # your implementation; returns a vector of size N

def diffusion(t, y, args):
    # y is a tuple
    x0 = y[0]    # size N
    y0 = y[1]    # scalar
    z_0 = Z(t, x0)
    ret1 = sigma_fn(t, x0, y0)
    ret2 = z_0 * sigma_fn(t, x0, y0)    # a vector of size N
    # return a tuple (size N, size N)
    return (ret1, ret2)
```
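A quick way to check the shapes is to plug in hypothetical toy choices for `u` and `sigma_fn` (a volatility of `0.4 * x`, chosen only for illustration) and confirm that `diffusion` returns a pair of size-N arrays. This variant evaluates `sigma_fn` once and reuses the result.

```python
import jax
import jax.numpy as jnp

def u(t, x):
    # Hypothetical scalar-valued stand-in for the network u(t, x).
    return jnp.exp(-t) * jnp.sum(x ** 2)

def Z(t, X):
    return jax.grad(lambda x: u(t, x))(X)

def sigma_fn(t, x0, y0):
    return 0.4 * x0                   # hypothetical diagonal diffusion, size N

def diffusion(t, y, args):
    x0, y0 = y                        # size N, size 1
    z_0 = Z(t, x0)                    # size N
    sigma = sigma_fn(t, x0, y0)       # evaluated once, reused below
    return (sigma, z_0 * sigma)       # (size N, size N)

N = 3
y = (jnp.ones(N), jnp.array([3.0]))
vf_x, vf_y = diffusion(0.0, y, None)
print(vf_x.shape, vf_y.shape)  # (3,) (3,)
```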

You then need to subclass `ControlTerm` and, in particular, override its `prod` method, which describes how the vector field (of the diffusion term) is multiplied with the Brownian motion increment.

```python
import diffrax
import jax.numpy as jnp
from jaxtyping import PyTree

class CustomControlTerm(diffrax.ControlTerm):

    @staticmethod
    def prod(vf: PyTree, control: PyTree) -> PyTree:
        # `vf` is the output of the diffusion function above
        first = vf[0]    # the first element (size N)
        second = vf[1]   # the second element (size N)

        # `control` is the N-dimensional Brownian increment dW, an array of size N
        prod1 = first * control              # size N
        prod2 = jnp.sum(second * control)    # scalar; add keepdims=True if Y has shape (1,)

        # return a PyTree: a tuple (size N, scalar)
        return (prod1, prod2)
```
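Since `prod` is just a function of two PyTrees, its logic can be tested on its own with plain JAX arrays before wiring it into diffrax. The sketch below copies the body of `prod` into a free function (the name `custom_prod` is hypothetical) and uses `keepdims=True`, matching the asker's convention of storing $Y$ with shape `(1,)`.

```python
import jax.numpy as jnp

def custom_prod(vf, control):
    # vf: the (size N, size N) tuple returned by `diffusion`
    # control: the Brownian increment dW, an array of size N
    first, second = vf
    prod1 = first * control                            # size N: sigma_i * dW_i
    prod2 = jnp.sum(second * control, keepdims=True)   # size 1: sum_i z_i * sigma_i * dW_i
    return (prod1, prod2)

vf = (jnp.array([1.0, 2.0, 3.0]), jnp.array([0.5, 0.5, 0.5]))
dW = jnp.array([0.1, -0.2, 0.3])
dX, dY = custom_prod(vf, dW)
```

By hand, `dX` should be `[0.1, -0.4, 0.9]` and `dY` should be `0.5 * (0.1 - 0.2 + 0.3) = 0.1`, up to float32 rounding.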

Finally, construct the new control term like this:

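```python
# fine for fixed step sizes; diffrax.VirtualBrownianTree also supports adaptive steps
control = diffrax.UnsafeBrownianPath(shape=(N,), key=...)
diffusion_term = CustomControlTerm(diffusion, control)
```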
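Before relying on diffrax, it can help to have a plain-JAX reference for what the drift and the custom diffusion compute at each step. The sketch below is a hand-written Euler loop using `jax.lax.scan` (not diffrax's implementation), with hypothetical toy `u`, `mu`, `phi` and `sigma_fn`. Each step draws one Brownian increment `dW` and feeds it to both the X and Y updates, which is the shared-noise point raised earlier in the thread.

```python
import jax
import jax.numpy as jnp

N, n_steps, T = 3, 50, 1.0
dt = T / n_steps

def u(t, x):
    # Hypothetical scalar-valued stand-in for the network u(t, x).
    return jnp.exp(-t) * jnp.sum(x ** 2)

def Z(t, x):
    return jax.grad(lambda x_: u(t, x_))(x)

def mu(t, x, y, z):
    return jnp.zeros_like(x)          # hypothetical drift of X

def phi(t, x, y, z):
    return 0.05 * y                   # hypothetical drift of Y

def sigma_fn(t, x, y):
    return 0.4 * x                    # hypothetical diagonal diffusion of X

def custom_prod(vf, control):
    first, second = vf
    return (first * control, jnp.sum(second * control, keepdims=True))

def step(carry, inputs):
    x, y = carry
    t, dW = inputs
    z = Z(t, x)
    sigma = sigma_fn(t, x, y)
    dX, dY = custom_prod((sigma, z * sigma), dW)   # the same dW drives X and Y
    x_new = x + mu(t, x, y, z) * dt + dX
    y_new = y + phi(t, x, y, z) * dt + dY
    return (x_new, y_new), y_new

key = jax.random.PRNGKey(0)
ts = jnp.arange(n_steps) * dt
dWs = jnp.sqrt(dt) * jax.random.normal(key, (n_steps, N))
x0 = jnp.ones(N)
y0 = u(0.0, x0)[None]                 # initial condition Y_0 = u(0, X_0), shape (1,)
(xT, yT), ys = jax.lax.scan(step, (x0, y0), (ts, dWs))
print(xT.shape, yT.shape, ys.shape)   # (3,) (1,) (50, 1)
```

This loop and a diffrax Brownian motion object draw their noise differently, so the two will not agree path by path; comparing statistics over many keys is a more meaningful check.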

That should be it. Let me know if you have more questions.