fhchl opened this issue 2 years ago
Right, so "controlled differential equations" are a specific notion, and are written like $\mathrm{d}y(t) = f(y(t))\,\mathrm{d}x(t)$; in some sense they are indeed forced by the derivative of $x$. This is a concept coming out of rough path theory, and is particularly interesting for the case where $x(t) = [t, w(t)]$ is a time-augmented Brownian motion. Indeed, you've probably seen stuff like $\mathrm{d}y(t) = \mu(t, y(t))\,\mathrm{d}t + \sigma(t, y(t))\,\mathrm{d}w(t)$, i.e. an SDE, written down before.
If you want it to be directly forced as you describe, then the answer is to encode this time dependence directly into the vector field:
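```python
import diffrax
import jax.numpy as jnp
x = jnp.sin  # the forcing signal: any JAX-traceable function of t

def vector_field(t, y, args):
    signal = x(t)  # evaluate the forcing at the current time
    return -y + signal  # illustrative dynamics: dy/dt = -y + x(t)

term = diffrax.ODETerm(vector_field)
sol = diffrax.diffeqsolve(term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array(1.0))
```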
If you think this would be valuable and would be happy to add it, then I'd be very happy to accept a PR adding a short example on this. (e.g. using the Stiff ODE example as a starting point.) I definitely recognise that this kind of direct forcing appears more frequently in e.g. the engineering literature.
(If you're curious to know more about CDEs, then I'd recommend Chapter 3 of the recently-released *On Neural Differential Equations* for an introduction. There are other references available too; I just happen to have written that one recently ;) )
I'll definitely have a look at the thesis, thanks!
> If you want it to be directly forced as you describe, then the answer is to encode this time dependence directly into the vector field
The most straightforward way would be to include the forcing term in `vector_field`, if one had a functional representation of it. Let's say that the forcing signal was measured, so it is given as some collection of samples. One could use one of the interpolation schemes to get the functional representation and then include that in `vector_field`.
Do I understand correctly that this would also be an option for the CDE, in case the control is differentiable? Something along these lines:
```python
def vector_field(t, y, args):
    dx = control.derivative(t)  # dx/dt of the control at time t
    ...

term = diffrax.ODETerm(vector_field)
diffrax.diffeqsolve(term, ...)
```
Is it correct to assume, then, that the `ControlTerm` abstraction is especially useful for SDEs?
> The most straightforward way would be to include the forcing term in `vector_field` - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into `vector_field`.
Yep! So you end up with
```python
def vector_field(t, y, args):
    x = control.evaluate(t)  # forcing value at time t, e.g. from an interpolation
    ...
```
> Do I understand it correctly, that this would also be an option for the CDE, in case the control is differentiable? Something along
Yep. In fact this is already built into Diffrax and wouldn't need to be (re)implemented manually. See `ControlTerm.to_ode`.
> Is it correct to assume then, that the `ControlTerm` abstraction is especially useful for the SDEs?
I think the two main uses are (a) SDEs and (b) CDEs done via `dx` rather than `x`.
(FWIW I don't know of a strong reason to prefer doing CDEs one way over the other. I know of a short list of advantages and disadvantages, but none of them are that big of a deal.)
Got it!
I have been using JAX so far to fit physical ODE models. I could cook up a little example for that use case with the approach discussed above.
> > The most straightforward way would be to include the forcing term in `vector_field` - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into `vector_field`.
>
> Yep! So you end up with
>
> ```python
> def vector_field(t, y, args):
>     x = control.evaluate(t)
>     ...
> ```
Great work with the library!
I have a use-case where I have an ODE w/ a (static) parameterized forcing function. I wanted to double check a few things:

1. … `args` and have an `eval = forcing_function(args, t)` call. That seems perfectly reasonable given the above discussion.
2. … `args` is explicitly in the function call. Is that the case? Is the `vector_field` differentiable wrt `args`?
3. … `diffeqsolve` wrt `args`, correct?

Thanks in advance!
Yep, that's exactly right.
Very cool work!
It would be great to have an example of how to solve an ODE that is directly forced by some signal $x(t)$, e.g. a forced mass-spring-damper $m y'' + r y' + k y = x(t)$.

Do I understand correctly that the controlled ODEs are "forced" by the derivative of $x(t)$?