patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Solving simple dynamics: ControlTerm piecewise product #358

Open hsimonfroy opened 7 months ago

hsimonfroy commented 7 months ago

Hi, and thanks first of all for developing this nice package.

For context, I intend to use diffrax to implement custom Langevin-like dynamics, but my issue reduces to the following. Let's say I want to implement a simple $n$-dimensional Brownian motion: $$dX = dB$$

I can try doing

diffusion = lambda t, y, args: jnp.ones(n)
brownian_motion = VirtualBrownianTree(t0, t1, tol=dt/2, shape=(n,), key=seed)
terms = ControlTerm(diffusion, brownian_motion)

but this won't work because of this line, which defines the vf-contr product for ControlTerm. What tensordot(vf, contr, axes=ndim(contr)) does is fully contract the tensors (over all the dimensions of contr), so in my case it would return the scalar $\sum_i dB_i$, whereas I generally need a piecewise, i.e. elementwise (Hadamard), product vf * contr.

For now, the only way I have found to implement a piecewise ControlTerm product is to increase the dimensionality of the vector field, e.g. in this case to write diffusion = lambda t, y, args: jnp.eye(n), which is far more expensive ($O(n^2)$) and will not scale to my applications. I am also not sure that jax.experimental.sparse matrices would help.
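To make the shapes concrete, here is a minimal sketch in plain jax.numpy (no diffrax involved) of the three products discussed above. The array dB is a hypothetical stand-in for one Brownian increment, not something produced by the library.

```python
import jax.numpy as jnp

n = 4
# Hypothetical stand-in for one Brownian increment of shape (n,).
dB = jnp.arange(1.0, n + 1.0)

# Vector-valued vf: contracting over all axes of dB collapses to a scalar,
# namely the sum of the components of dB.
full = jnp.tensordot(jnp.ones(n), dB, axes=jnp.ndim(dB))

# Matrix-valued vf: the identity recovers a vector, but stores n * n entries.
via_eye = jnp.tensordot(jnp.eye(n), dB, axes=jnp.ndim(dB))

# Elementwise (Hadamard) product: same vector, only n entries needed.
hadamard = jnp.ones(n) * dB

print(full.shape, via_eye.shape, hadamard.shape)
```

The identity-matrix version and the elementwise version agree, which is why jnp.eye(n) works as a workaround while costing $O(n^2)$ memory.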

I understand that matrix products, or even higher-rank tensor products, may be required in some applications. This recent question and this diffrax Neural SDE example both have a matrix-valued diffusion vector field and a vector-valued Brownian control. However, if I am not mistaken, that same full tensor contraction means a matrix product between a matrix-valued vf and a matrix-valued control is not currently easy to implement.

So my question is: did I miss a way to implement a piecewise ControlTerm product? I think it should be possible to implement it:

I could not think of any einsum to replace tensordot(a,b,ndim(b)) that would fit well in all cases, but maybe having a way to specify which product _prod function to use in ControlTerm could be an idea? Or maybe I just missed a simple way to do everything above.

Thanks in advance!

patrick-kidger commented 7 months ago

You want diffrax.WeaklyDiagonalControlTerm instead of just diffrax.ControlTerm. :)

> maybe having a way to specify which product _prod function to use in ControlTerm could be an idea?

For this more general case, you can subclass diffrax.AbstractTerm and then implement the appropriate product you have in mind. (Just like the built-in terms!)

hsimonfroy commented 6 months ago

Thanks, works fine!

The SDE part of the getting started guide redirects to the Terms page, so I should have seen it ;)

Also, concerning the Brownian control: it seems that changes in VirtualBrownianTree mean diffrax no longer supports reverse-time SDEs in the new 0.5.0 version. One gets a "t0 must be strictly less than t1" error (and reversing t0 and t1 in the call does not help), whereas reverse-time ODEs still work fine.

patrick-kidger commented 6 months ago

Ah, interesting point about the Brownian motion. FWIW since it's just a control then I think it shouldn't matter too much -- just switch them before passing them to the control. That said I'd be happy to add a PR that makes this "just work". (Ideally negating the generated samples if t0 > t1.)