patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0
1.46k stars, 133 forks

Example for solving forced ODE #56

Open fhchl opened 2 years ago

fhchl commented 2 years ago

Very cool work!

It would be great to have an example on how to solve an ODE that is directly forced by some signal x(t), e.g. a forced mass-spring-damper

m y'' + r y' + k y = x(t).

Do I understand correctly, that the controlled ODEs are "forced" by the derivative of x(t)?

patrick-kidger commented 2 years ago

Right, so "controlled differential equations" are a specific notion, and are written like dy(t) = f(y(t)) dx(t); in some sense they are indeed forced by the derivative of x. This is a concept coming out of rough path theory, and is particularly interesting for the case where x(t) = [t, w(t)] is a time-augmented Brownian motion. Indeed you've probably seen stuff like dy(t) = μ(t, y(t)) dt + σ(t, y(t)) dw(t), i.e. an SDE, written down before.


If you want it to be directly forced as you describe, then the answer is to encode this time dependence directly into the vector field:

def vector_field(t, y, args):
    signal = x(t)
    ...
term = diffrax.ODETerm(vector_field)
diffrax.diffeqsolve(term, ...)

If you think this would be valuable and would be happy to add it, then I'd be very happy to accept a PR adding a short example on this. (e.g. using the Stiff ODE example as a starting point.) I definitely recognise that this kind of direct forcing appears more frequently in e.g. the engineering literature.

patrick-kidger commented 2 years ago

(If you're curious to know more about CDEs then I'd recommend Chapter 3 of the recently-released On Neural Differential Equations for an introduction. There's other references available too; I just happen to have written that one recently ;) )

fhchl commented 2 years ago

I'll definitely have a look at the thesis, thanks!

> If you want it to be directly forced as you describe, then the answer is to encode this time dependence directly into the vector field

The most straightforward way would be to include the forcing term in vector_field - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into vector_field.

Do I understand it correctly that this would also be an option for the CDE, in case the control is differentiable? Something along the lines of:

def vector_field(t, y, args):
    dx = control.derivative(t)
    ...
term = diffrax.ODETerm(vector_field)
diffrax.diffeqsolve(term, ...)

Is it correct to assume then, that the ControlTerm abstraction is especially useful for the SDEs?

patrick-kidger commented 2 years ago

> The most straightforward way would be to include the forcing term in vector_field - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into vector_field.

Yep! So you end up with

def vector_field(t, y, args):
    x = control.evaluate(t)
    ...

> Do I understand it correctly, that this would also be an option for the CDE, in case the control is differentiable? Something along

Yep. In fact this is already built into Diffrax and wouldn't need to be (re)implemented manually. See ControlTerm.to_ode.

> Is it correct to assume then, that the ControlTerm abstraction is especially useful for the SDEs?

I think the two main uses are (a) SDEs and (b) CDEs done via dx rather than x.

(FWIW I don't know of a strong reason to prefer doing CDEs one way over the other. I know of a short list of advantages and disadvantages, but none of them are that big of a deal.)

fhchl commented 2 years ago

Got it!

I have been using JAX so far to fit physical ODE models. I could cook up a little example for that use case with the approach discussed above.

joglekara commented 2 years ago

> > The most straightforward way would be to include the forcing term in vector_field - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into vector_field.
>
> Yep! So you end up with
>
> def vector_field(t, y, args):
>     x = control.evaluate(t)
>     ...

Great work with the library!

I have a use-case where I have an ODE w/ a (static) parameterized forcing function. I wanted to double check a few things

Thanks in advance!

patrick-kidger commented 2 years ago

Yep, that's exactly right.