SciML / ModelingToolkit.jl

An acausal modeling framework for automatically parallelized scientific machine learning (SciML) in Julia. A computer algebra system for integrated symbolics for physics-informed machine learning and automated transformations of differential equations.
https://mtk.sciml.ai/dev/

Cannot evaluate spline parameter directly in an ODESystem #2823

Open hersle opened 5 days ago


I want to evaluate a spline and its derivative(s) through an ODESystem. When the spline is hardcoded into the equations, this works beautifully:

using Test
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using DifferentialEquations
using DataInterpolations

ts = 0.0:0.1:1.0
Fspline_t² = CubicSpline(ts .^ 2, ts) # spline for F(t) = t²

# F(t) is given directly by the spline; F′(t) is its time derivative
@variables F(t) F′(t)
@named sys = ODESystem([
    F ~ Fspline_t²(t)
    F′ ~ D(F)
], t)
ssys = structural_simplify(sys)

prob = ODEProblem(ssys, [], (0.0, 1.0), [])
sol = solve(prob)
@test sol(0.5, idxs=F′) ≈ 2 * 0.5 # F′(t) = 2t
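
For reference, the value the test expects can be checked against the spline alone, without ModelingToolkit. This is a minimal sketch that uses only DataInterpolations' derivative function; the names F_half and F′_half are illustrative. Since the spline interpolates t² at the sample points and approximates its slope, the derivative at t = 0.5 should come out close to 2 · 0.5 = 1.0.

```julia
using DataInterpolations

# Same spline as in the example above: samples of F(t) = t² on [0, 1]
ts = 0.0:0.1:1.0
Fspline_t² = CubicSpline(ts .^ 2, ts)

F_half = Fspline_t²(0.5)                                 # spline value at t = 0.5
F′_half = DataInterpolations.derivative(Fspline_t², 0.5) # first derivative at t = 0.5
println("F(0.5) ≈ ", F_half, ", F′(0.5) ≈ ", F′_half)
```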

However, to be able to pass in different splines, I want to make the spline a parameter:

using Test
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using DifferentialEquations
using DataInterpolations

ts = 0.0:0.1:1.0
Fspline_t² = CubicSpline(ts .^ 2, ts) # spline for F(t) = t²

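# Make the spline a parameter so that different splines can be passed in
@parameters Fspline::CubicSpline = Fspline_t²
@variables F(t) F′(t)
@named sys = ODESystem([
    F ~ Fspline(t) # throws: Sym Fspline is not callable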
    F′ ~ D(F)
], t)
ssys = structural_simplify(sys)

prob = ODEProblem(ssys, [], (0.0, 1.0), [])
sol = solve(prob)
@test sol(0.5, idxs=F′) ≈ 2 * 0.5 # F′(t) = 2t

This fails with

ERROR: LoadError: Sym Fspline is not callable. Use @syms Fspline(var1, var2,...) to create it as a callable.

Accordingly, I tried declaring @parameters Fspline(t)::CubicSpline = Fspline_t² instead, but that does not help either. Is this a bug?

I have worked around half of the problem by evaluating the spline through a proxy function, writing F ~ spleval(t, Fspline) with

spleval(x, spline) = spline(x) # numeric evaluation of the spline
@register_symbolic spleval(x, spline::CubicSpline) # keep spleval unevaluated for symbolic arguments

This works for F, but not for F′, which then evaluates to, e.g., Differential(t)(1.0). It also feels unnecessarily complicated, because I believe these derivative rules are already defined in DataInterpolations.jl.
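
The F′ half of the proxy workaround might be addressed by also registering the proxy's derivative. Below is a minimal sketch, assuming Symbolics' mechanism for derivatives of registered functions (a Symbolics.derivative method dispatching on the function type and the argument index) and DataInterpolations.derivative for the numeric slope. The helper spleval_deriv is hypothetical, and the sketch is only checked on a plain symbolic expression, not inside an ODESystem.

```julia
using Symbolics
using DataInterpolations

# Hypothetical proxy functions (illustrative names, not library API):
# spleval evaluates the spline, spleval_deriv evaluates its first derivative.
spleval(x, spline) = spline(x)
spleval_deriv(x, spline) = DataInterpolations.derivative(spline, x)

# Keep both calls opaque to Symbolics when x is symbolic
@register_symbolic spleval(x, spline::CubicSpline)
@register_symbolic spleval_deriv(x, spline::CubicSpline)

# Derivative rule w.r.t. the first argument: d/dx spleval(x, s) = spleval_deriv(x, s)
Symbolics.derivative(::typeof(spleval), args::NTuple{2,Any}, ::Val{1}) =
    spleval_deriv(args[1], args[2])

ts = 0.0:0.1:1.0
spline = CubicSpline(ts .^ 2, ts) # spline for F(x) = x²

@variables x
dexpr = expand_derivatives(Differential(x)(spleval(x, spline))) # → spleval_deriv(x, spline)
val = Symbolics.value(substitute(dexpr, Dict(x => 0.5)))          # numeric slope at x = 0.5
println(val)
```

In the ODESystem, the same rule would presumably let D(spleval(t, Fspline)) expand to spleval_deriv(t, Fspline); whether structural_simplify picks this up is an assumption not verified here.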

I would really like the second example to work like the first, without these workarounds. Is that possible?

Thanks!