SciML / ModelingToolkit.jl

An acausal modeling framework for automatically parallelized scientific machine learning (SciML) in Julia. A computer algebra system for integrated symbolics for physics-informed machine learning and automated transformations of differential equations.
https://mtk.sciml.ai/dev/

Performant way to remake ODEProblem with updated external input function? #2646

Closed: hersle closed this issue 2 weeks ago

hersle commented 5 months ago

I have an ODE for x(t) that depends on a parameter P not only explicitly, but also implicitly through an external input (forcing) function F(t). I want to solve the ODE efficiently, many times, for different values of P, each time updating F(t) to reflect its dependence on P. Eventually, I also want to autodifferentiate the solution with respect to P, with the derivative correctly propagating through F(t).

Here is the closest I have come, trying to use remake on a simple example:

using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using DifferentialEquations
using DataInterpolations
using ForwardDiff

P0 = 2.0 # test value for the parameter P

# F(t) is an external function that depends on the parameter P
function create_F(P)
    ts = range(0.0, 1.0, step=0.1)
    Fs = ts .^ P # toy example: F = t^P
    return QuadraticSpline(Fs, ts) # DataInterpolations expects (u, t): values first, then times
end
F_spline = create_F(NaN) # uninitialized spline (replace NaN -> P0 for code to run)
F(t) = F_spline(t)

# an ODE that depends on the parameter P both explicitly, and implicitly through F(t)
@parameters P
@variables x(t)
@register_symbolic F(t) # following https://docs.sciml.ai/ModelingToolkit/stable/tutorials/ode_modeling/#Specifying-a-time-variable-forcing-function
sys = structural_simplify(ODESystem([D(x) ~ P * F(t)], t, [x], [P]; name = :sys))
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [P => NaN]) # uninitialized problem

# solution of the ODE for a given value of the parameter P (must be fast!)
function solve_instance_fast(P)
    F_instance = create_F(P)
    prob_instance = remake(prob; p = [P]) # QUESTION: how to update F in prob to F_instance?
    return solve(prob_instance)
end

ForwardDiff.derivative(P -> solve_instance_fast(P)(1.0; idxs=x), P0)

Is there a way to accomplish this with ModelingToolkit?
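
Why F has to be rebuilt for every P, rather than built once, can be read off the toy system itself. Writing F(t; P) to make the dependence explicit, with D(x) ~ P * F(t) and x(0) = 0, and differentiating under the integral (assuming F is smooth enough in P for that):

$$
x(1) = \int_0^1 P\,F(t;P)\,dt,
\qquad
\frac{d\,x(1)}{dP} = \int_0^1 F(t;P)\,dt \;+\; P\int_0^1 \frac{\partial F(t;P)}{\partial P}\,dt .
$$

If F is held fixed while P varies, only the first term survives. For the toy forcing F = t^P we have dF/dP = t^P ln t and the integral of t^P ln t over [0, 1] is -1/(P+1)^2, so at P0 = 2 the first term is 1/3, the second is -2/9, and the correct total is 1/9.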

ChrisRackauckas commented 5 months ago

Why not just make the spline the parameter and update that?
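
In plain Julia, the suggestion amounts to carrying the interpolant inside the parameter object and rebuilding it for each P, so nothing global needs to be mutated. This is a minimal Base-only sketch under stated assumptions: LinearInterp is a hypothetical piecewise-linear interpolant standing in for the DataInterpolations spline, rk4 is a hand-written fixed-step integrator standing in for solve, and central finite differences stand in for ForwardDiff. The frozen variant builds F once and reuses it, which silently drops the dependence of F on P from the derivative.

```julia
# Hypothetical piecewise-linear interpolant, a stand-in for a DataInterpolations spline.
# Parametric in T so that other number types (e.g. ForwardDiff duals) could flow through.
struct LinearInterp{T,R<:AbstractVector}
    u::Vector{T}  # values at the knots
    t::R          # sorted knot times
end

function (f::LinearInterp)(s)
    # interval containing s, clamped so the end points use the edge intervals
    i = clamp(searchsortedlast(f.t, s), 1, length(f.t) - 1)
    w = (s - f.t[i]) / (f.t[i+1] - f.t[i])
    return (1 - w) * f.u[i] + w * f.u[i+1]
end

# Same toy forcing as in the thread: F(t) = t^P sampled on a 0.1 grid
function create_F(P)
    ts = range(0.0, 1.0, step=0.1)
    return LinearInterp(collect(ts .^ P), ts)
end

# Right-hand side: the interpolant is just another field of the parameter object p
rhs(x, p, t) = p.P * p.F(t)

# Hypothetical fixed-step classical RK4 integrator, standing in for solve
function rk4(f, x0, p, tspan; n=1000)
    t0, t1 = tspan
    h = (t1 - t0) / n
    x = x0
    for k in 1:n
        t = t0 + (k - 1) * h
        k1 = f(x, p, t)
        k2 = f(x + h / 2 * k1, p, t + h / 2)
        k3 = f(x + h / 2 * k2, p, t + h / 2)
        k4 = f(x + h * k3, p, t + h)
        x += h / 6 * (k1 + 2k2 + 2k3 + k4)
    end
    return x
end

# x(1) for a given P, rebuilding the forcing from P every time
x_at_1(P) = rk4(rhs, 0.0, (P = P, F = create_F(P)), (0.0, 1.0))

# Central finite difference in P; F is rebuilt for each perturbed P
dxdP(P; δ=1e-5) = (x_at_1(P + δ) - x_at_1(P - δ)) / (2δ)

# Same, but with F built once at P and frozen, so its P-dependence is lost
function dxdP_frozen(P; δ=1e-5)
    F = create_F(P)
    g(Q) = rk4(rhs, 0.0, (P = Q, F = F), (0.0, 1.0))
    return (g(P + δ) - g(P - δ)) / (2δ)
end
```

With P = 2, x_at_1 lands near 2/3 and dxdP near 1/9 (not exactly, because of the linear interpolation on a coarse grid), while dxdP_frozen lands near 1/3: freezing the forcing loses the second term of the sensitivity.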

hersle commented 5 months ago

That works! Here is a working example:

using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using DifferentialEquations
using DataInterpolations
using ForwardDiff

P0 = 2.0 # test value for the parameter P

# F(t) is an external function that depends on the parameter P
function create_F(P)
    ts = range(0.0, 1.0, step=0.1)
    Fs = ts .^ P # toy example: F = t^P
    return CubicSpline(Fs, ts)
end
F(t, spline) = spline(t) # intermediate function for F(t) (possible to avoid this?)

# an ODE that depends on the parameter P both explicitly, and implicitly through F(t)
@parameters P F_spline
@variables x(t)
@register_symbolic F(t, spline) # following https://docs.sciml.ai/ModelingToolkit/stable/tutorials/ode_modeling/#Specifying-a-time-variable-forcing-function
sys = structural_simplify(ODESystem([D(x) ~ P * F(t, F_spline)], t, [x], [P, F_spline]; name = :sys))
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [P => NaN, F_spline => NaN]) # uninitialized problem

# solution of the ODE for a given value of the parameter P (must be fast!)
function solve_instance_fast(P)
    F_instance = create_F(P)
    prob_instance = remake(prob; p = [sys.P => P, sys.F_spline => F_instance]) # update P and spline for F(t)
    return solve(prob_instance)
end

x_at_1(P) = solve_instance_fast(P)(1.0; idxs=x) # analytical result: x(1) = P/(P+1)
ForwardDiff.derivative(x_at_1, P0) # analytical result: d(x(t=1))/dP = 1/(P+1)^2
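
The analytical values in the comments above follow from integrating the toy ODE directly (x(0) = 0, F = t^P). They serve as a reference for the spline-based numerical results, which will differ slightly because of interpolation and solver error:

$$
x(t) = \int_0^t P\,s^P\,ds = \frac{P\,t^{P+1}}{P+1},
\qquad
x(1) = \frac{P}{P+1},
\qquad
\frac{d\,x(1)}{dP} = \frac{(P+1) - P}{(P+1)^2} = \frac{1}{(P+1)^2}.
$$

For P0 = 2 this gives x(1) = 2/3 ≈ 0.667 and dx(1)/dP = 1/9 ≈ 0.111.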

Thank you! I thought I was perhaps dealing with a restriction of MTK, and this simple solution just did not cross my mind.

cstjean commented 2 weeks ago

With MTK 9.32.0, it works until the last line, which now yields:

TypeError: in validate_parameter_type, in Parameter F_spline, expected Real, got a value of type DataInterpolations.CubicSpline{ReadOnlyArrays.ReadOnlyVector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.var"workspace#8".x_at_1), Float64}, Float64, 1}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.var"workspace#8".x_at_1), Float64}, Float64, 1}}}, ReadOnlyArrays.ReadOnlyVector{Float64, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.var"workspace#8".x_at_1), Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.var"workspace#8".x_at_1), Float64}, Float64, 1}}, Vector{Float64}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.var"workspace#8".x_at_1), Float64}, Float64, 1}}, ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.var"workspace#8".x_at_1), Float64}, Float64, 1}}

Stack trace (most recent call first; long type parameters abbreviated as {…}):

validate_parameter_type(ic::ModelingToolkit.IndexCache, p::Symbolics.Num, index::ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Int64}, val::DataInterpolations.CubicSpline{…}) @ parameter_buffer.jl:545
remake_buffer(indp::SciMLBase.ODEProblem{…}, oldbuf::ModelingToolkit.MTKParameters{…}, vals::Dict{Any, Any}) @ parameter_buffer.jl:588
_updated_u0_p_symmap(prob::SciMLBase.ODEProblem{…}, u0::Vector{Float64}, ::Val{false}, p::Dict{Any, Any}, ::Val{true}) @ remake.jl:583
_updated_u0_p_internal(prob::SciMLBase.ODEProblem{…}, ::Missing, p::Vector{Pair{Symbolics.Num, Any}}; interpret_symbolicmap::Bool, use_defaults::Bool) @ remake.jl:493
_updated_u0_p_internal @ remake.jl:482
#updated_u0_p#717 @ remake.jl:645
updated_u0_p @ remake.jl:626
remake(prob::SciMLBase.ODEProblem{…}; f::Missing, u0::Missing, tspan::Missing, p::Vector{Pair{Symbolics.Num, Any}}, kwargs::Missing, interpret_symbolicmap::Bool, use_defaults::Bool, _kwargs::@Kwargs{}) @ remake.jl:101
solve_instance_fast(P::ForwardDiff.Dual{ForwardDiff.Tag{typeof(Main.var"workspace#8".x_at_1), Float64}, Float64, 1}) @ Other cell: line 3

AayushSabharwal commented 2 weeks ago

This needs a type annotation. Parameters that are not real numbers must have their types specified explicitly; an unannotated parameter is assumed to be Real, which is why validate_parameter_type rejects the spline.

@parameters P F_spline::CubicSpline
@variables x(t)
@register_symbolic F(t, spline::CubicSpline) # following https://docs.sciml.ai/ModelingToolkit/stable/tutorials/ode_modeling/#Specifying-a-time-variable-forcing-function
sys = structural_simplify(ODESystem([D(x) ~ P * F(t, F_spline)], t, [x], [P, F_spline]; name = :sys))
prob = ODEProblem(sys, [x => 0.0], (0.0, 1.0), [P => NaN, F_spline => DummyCubicSpline()]) # uninitialized problem

function solve_instance_fast(P)
    F_instance = create_F(P)
    prob_instance = remake(prob; p = [sys.P => P, sys.F_spline => F_instance]) # update P and spline for F(t)
    return solve(prob_instance)
end

x_at_1(P) = solve_instance_fast(P)(1.0; idxs=x) # analytical result: x(1) = P/(P+1)
ForwardDiff.derivative(x_at_1, P0) # analytical result: d(x(t=1))/dP = 1/(P+1)^2
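
Why the bare annotation F_spline::CubicSpline is enough, as I understand it: under ForwardDiff.derivative, create_F(P) returns a CubicSpline whose type parameters contain ForwardDiff.Dual (see the error above), so the declared type has to admit splines of any element type. A parametric type written without its parameters is a UnionAll that does exactly that, whereas Real admits no spline at all. The sketch below shows this with plain Julia types only; ToySpline is a hypothetical stand-in for a parametric spline type and BigFloat stands in for the dual number type. It illustrates Julia's type system, not MTK's internal checks.

```julia
# Hypothetical stand-in for a spline type that is parametric in its element type,
# in the way DataInterpolations.CubicSpline is.
struct ToySpline{T}
    u::Vector{T}   # knot values
end

s_float = ToySpline([0.0, 1.0])           # Float64 values, as when P is a plain Float64
s_other = ToySpline(BigFloat[0.0, 1.0])   # another number type, standing in for ForwardDiff.Dual

println(s_float isa Real)                 # false: a spline is not a number, so a Real-typed slot rejects it
println(s_float isa ToySpline)            # true
println(s_other isa ToySpline)            # true: the bare UnionAll admits any element type
println(s_other isa ToySpline{Float64})   # false: a concrete annotation would reject it
```

The same reasoning suggests annotating with the unparameterized spline type rather than the concrete type of the dummy spline used to build the problem.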

cstjean commented 2 weeks ago

Worked for me, thank you for the quick answer. One small detail: your example is missing this definition:

DummyCubicSpline() = create_F(1)