SciML / ModelingToolkit.jl

An acausal modeling framework for automatically parallelized scientific machine learning (SciML) in Julia. A computer algebra system for integrated symbolics for physics-informed machine learning and automated transformations of differential equations
https://mtk.sciml.ai/dev/

Method for evaluating ODE's drift function at parameter set/variables #2555

Open TorkelE opened 6 months ago

TorkelE commented 6 months ago

Assuming that I have an ODE

dX/dt = p - k*X
dY/dt = k*X - d*Y

I create an ODESystem using

using ModelingToolkit
import ModelingToolkit: t_nounits as t, D_nounits as D
@variables X(t) Y(t)
@parameters p k d
eqs = [
    D(X) ~  p - k*X
    D(Y) ~  k*X - d*Y
]
@mtkbuild osys = ODESystem(eqs, t)

Now, if I want to simulate this, I can simply create an ODEProblem. However, it would be useful to have a way to evaluate the drift function at a single point. That is, I want the values of dX/dt and dY/dt at (X,Y) = (1.0, 1.0), (p,k,d) = (1.0, 2.0, 0.5). So I have

u = [X => 1.0, Y => 1.0]
ps = [p => 1.0, k => 2.0, d => 0.5]

and expect the output

 [-1.0, 1.5]
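As a quick cross-check, the right-hand side can be evaluated by hand in plain Julia. This is only a sketch: `drift` is a hypothetical helper written for illustration, not part of ModelingToolkit. Substituting the values gives dX/dt = 1.0 - 2.0*1.0 = -1.0 and dY/dt = 2.0*1.0 - 0.5*1.0 = 1.5.

```julia
# Hand-written right-hand side of the example ODE (no ModelingToolkit involved).
# `drift` is a hypothetical name used only for this illustration.
drift(X, Y, p, k, d) = [p - k*X, k*X - d*Y]

# Evaluate at (X, Y) = (1.0, 1.0) and (p, k, d) = (1.0, 2.0, 0.5).
du = drift(1.0, 1.0, 1.0, 2.0, 0.5)
println(du)  # [-1.0, 1.5]
```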

Would it be possible to create a function for doing this?

Part of this would be to have similar ways to evaluate the diffusion function of SDESystems, and the NonlinearFunction of NonlinearSystems.

baggepinnen commented 6 months ago

The function is available as prob.f.f or something like that, but I don't think there is a user-facing utility that gives you this function, unfortunately. It would be useful to have, just as it would be useful to have a user-exposed way to obtain a function to compute observed variables given the state, time and parameters.

isaacsas commented 6 months ago

I think the issue isn't getting the function via prob.f.f; it is that with MTK9 users no longer seem to have an API-based way to construct the u and p objects to pass into this function unless they create an ODEProblem and access prob.u0 and prob.p. That is, one used to be able to create the inputs using varmap_to_vars to create ordered vectors, but such functionality doesn't seem to exist anymore.

TorkelE commented 6 months ago

The problem is that that function takes input in the form of MTKParameters. I.e.

u = [X => 1.0, Y => 1.0]
ps = [p => 1.0, k => 2.0, d => 0.5]
t = 0.0
prob.f(u, ps, t)

gives very weird outputs.

TorkelE commented 6 months ago

I.e. trying

using ModelingToolkit
import ModelingToolkit.t_nounits as t
import ModelingToolkit.D_nounits as D
@variables X(t) Y(t) Z(t)
@parameters p1 d η
eq1 = D(X) ~ p1
eq2 = D(Y) ~ 0*d
eq3 = D(Z) ~ -η
@mtkbuild osys = ODESystem([eq1, eq2, eq3], t)

u0 = [X => 0.0, Y => 0.0, Z => 0.0]
ps = [p1 => 100.0, d => 1.0, η => 1.0]
oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)

p_vals = ModelingToolkit.varmap_to_vars([p1 => 1.0, d => 2.0, η => 3.0], parameters(osys))
oprob.f([0.0, 0.0, 0.0], p_vals, 0.0)

One gets:

[3.0, -0.0, -1.0]

This is wrong: with η = 3.0, D(Z) = -η should come out as -3.0, but the 3.0 appears in the output with a positive sign.

baggepinnen commented 6 months ago

You could use p = oprob.p if you need something hacky right now. But I agree, it would be beneficial to have API for all of these things.

ChrisRackauckas commented 6 months ago

Though I will mention that 90% of what ODEProblem does is the construction of f, u0, and p so I'm a little confused as to what you think you're saving.

baggepinnen commented 6 months ago

I guess the fields of ODEProblem aren't considered public API?

ChrisRackauckas commented 6 months ago

They are

ChrisRackauckas commented 6 months ago

And the implementation of the solution here is to just use ODEProblem and the public API on the fields there, u0, p, and tspan[1], to do this.
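A minimal sketch of that approach, applied to the two-variable system from the top of the thread, might look as follows. It assumes ModelingToolkit v9 (the `@mtkbuild` / `ODESystem` API used in this thread), that `oprob.f` can be called out-of-place as `f(u, p, t)` (as the call earlier in the thread suggests), and that `remake` accepts symbolic maps as in recent SciMLBase versions. The tspan `(0.0, 1.0)` is arbitrary, since only its first entry is used.

```julia
using ModelingToolkit
import ModelingToolkit: t_nounits as t, D_nounits as D

@variables X(t) Y(t)
@parameters p k d
eqs = [D(X) ~ p - k*X,
       D(Y) ~ k*X - d*Y]
@mtkbuild osys = ODESystem(eqs, t)

u0 = [X => 1.0, Y => 1.0]
ps = [p => 1.0, k => 2.0, d => 0.5]

# Building the problem constructs f, u0 and p in the layout the system expects.
oprob = ODEProblem(osys, u0, (0.0, 1.0), ps)

# Evaluate the drift using the problem's own fields.
du = oprob.f(oprob.u0, oprob.p, oprob.tspan[1])

# The entries of du follow the order of unknowns(osys), so print them paired up.
for (var, val) in zip(unknowns(osys), du)
    println(var, " => ", val)
end

# To try another point, rebuild the inputs with remake (assumed to accept symbolic maps).
oprob2 = remake(oprob; u0 = [X => 2.0, Y => 1.0], p = [p => 1.0, k => 2.0, d => 0.5])
du2 = oprob2.f(oprob2.u0, oprob2.p, oprob2.tspan[1])
```

Pairing the result with `unknowns(osys)` avoids relying on the unknowns keeping the order in which they were declared, which simplification is not guaranteed to preserve.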

TorkelE commented 6 months ago

I still think there is an advantage to isolating the actual call, rather than creating an ODEProblem and then calling remake whenever one wants to try a new set of values. Right, you might only remove about one line of code per call, but it does make the code much more readable for people less familiar with what is actually going on underneath.

ChrisRackauckas commented 6 months ago

It would only remove one line, because doing this is currently 2 lines.

TorkelE commented 6 months ago

Thanks. Yeah, I hedged on the line count because I figured it would only be one line but wasn't 100% sure, and you are right that that is indeed the case.