SciML / DataDrivenDiffEq.jl

Data driven modeling and automated discovery of dynamical systems for the SciML Scientific Machine Learning organization
https://docs.sciml.ai/DataDrivenDiffEq/stable/
MIT License

Koopman inference fails for simple example while SINDy works #342

Closed (sdwfrost closed this issue 2 years ago)

sdwfrost commented 2 years ago

This is from a post on Discourse; the code is pasted below. Koopman inference gives the wrong result, while SINDy works fine.

```julia
using OrdinaryDiffEq
using DataDrivenDiffEq
using ModelingToolkit

# SIR model: s, i, r are the susceptible, infected and recovered fractions
function sir_ode(u,p,t)
    (s,i,r) = u
    (β,γ) = p
    ds = -β*s*i
    di = β*s*i - γ*i
    dr = γ*i
    [ds,di,dr]
end

p = [0.5,0.25]          # β = 0.5, γ = 0.25
u0 = [0.99, 0.01, 0.0]
tspan = (0.0, 40.0)
solver = ExplicitRK()
sir_prob = ODEProblem(sir_ode, u0, tspan, p)
sir_sol = solve(sir_prob, solver)
dd_prob = ContinuousDataDrivenProblem(sir_sol)
@parameters t
@variables u[1:3](t)
# Basis: the three states plus the bilinear term s*i
Ψ = Basis([u; u[1]*u[2]], u, independent_variable = t)
res_koopman = solve(dd_prob, Ψ, DMDPINV(), digits = 1)
sys_koopman = result(res_koopman)
```

The above gives:

```
julia> equations(sys_koopman)
3-element Vector{Equation}:
 Differential(t)(u[1](t)) ~ p₁*u[1](t)*u[2](t)
 Differential(t)(u[2](t)) ~ p₂*u[2](t) + p₃*u[1](t)*u[2](t)
 Differential(t)(u[3](t)) ~ p₄*u[2](t)
julia> parameter_map(res_koopman)
4-element Vector{Pair{Sym{Real, Base.ImmutableDict{DataType, Any}}, Float64}}:
 p₁ => -0.5
 p₂ => -0.2
 p₃ => 0.5
 p₄ => 0.2
```

p₂ and p₄ should be -0.25 and 0.25 (that is, -γ and γ with γ = 0.25), which is what I get with SINDy.
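The gap between p₂ = -0.2 and the expected -0.25 is easier to see with a stripped-down version of the fit. This sketch is a hypothetical stand-in, not DataDrivenDiffEq's implementation: it assumes DMDPINV in this example amounts to regressing the state derivatives on the basis Ψ = [u; u[1]*u[2]] through the Moore-Penrose pseudoinverse. A hand-written RK4 integrator (rk4) replaces OrdinaryDiffEq, and central differences replace ContinuousDataDrivenProblem; both are assumptions made to keep the example dependency-free.

```julia
using LinearAlgebra

# SIR right-hand side with the issue's parameters p = [β, γ] = [0.5, 0.25]
sir_rhs(u, p) = [-p[1] * u[1] * u[2],
                 p[1] * u[1] * u[2] - p[2] * u[2],
                 p[2] * u[2]]

# Hypothetical fixed-step RK4, standing in for solve(sir_prob, ExplicitRK())
function rk4(f, u0, p, tspan, h)
    ts = collect(tspan[1]:h:tspan[2])
    us = [copy(u0)]
    for _ in 1:length(ts)-1
        u = us[end]
        k1 = f(u, p)
        k2 = f(u .+ (h / 2) .* k1, p)
        k3 = f(u .+ (h / 2) .* k2, p)
        k4 = f(u .+ h .* k3, p)
        push!(us, u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4))
    end
    return ts, permutedims(reduce(hcat, us))   # rows are time points
end

p, u0, h = [0.5, 0.25], [0.99, 0.01, 0.0], 0.01
ts, X = rk4(sir_rhs, u0, p, (0.0, 40.0), h)

# Central-difference derivative estimates at the interior time points
DX = (X[3:end, :] .- X[1:end-2, :]) ./ (2h)
Xc = X[2:end-1, :]

# Library matching Ψ = [u; u[1]*u[2]]: columns s, i, r, s*i
Θ = hcat(Xc, Xc[:, 1] .* Xc[:, 2])

# Least-squares coefficients via the pseudoinverse;
# Ξ[j, k] multiplies library term j in the equation for state k
Ξ = pinv(Θ) * DX
println(round.(Ξ; digits = 3))
```

No rounding happens inside the fit, so the entries in the positions of p₁ to p₄ should come out close to -0.5, -0.25, 0.5 and 0.25, with the other entries near zero. That matches the thread: the coefficient is recoverable from the data, and it is the one-decimal rounding that loses it.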

AlCap23 commented 2 years ago

Changing the digits keyword from 1 to 2 works for me:

```julia
# As above...
res_koopman = solve(dd_prob, Ψ, DMDPINV(), digits = 2)
sys_koopman = result(res_koopman)
println(sys_koopman)
println(parameters(res_koopman))
```

Gives:

```
Model ##Koopman#309 with 3 equations
States : u[1](t) u[2](t) u[3](t)
Parameters : p₁ p₂ p₃ p₄
Independent variable: t
Equations
Differential(t)(u[1](t)) = p₁*u[1](t)*u[2](t)
Differential(t)(u[2](t)) = p₂*u[2](t) + p₃*u[1](t)*u[2](t)
Differential(t)(u[3](t)) = p₄*u[2](t)

[-0.5, -0.25, 0.5, 0.25]
```

Which is very nice 😄. May I use this example for future docs and presentations?
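Assuming digits rounds each estimated coefficient to that many decimal places (the outputs in this thread are consistent with that, though the package source is not shown here), one decimal place cannot represent γ = 0.25 at all. A short illustration with Base.round, using made-up estimates near the true values:

```julia
# Hypothetical coefficient estimates near the true values
# (p₁, p₂, p₃, p₄) = (-β, -γ, β, γ) with β = 0.5, γ = 0.25
estimates = [-0.5, -0.2499, 0.5, 0.2499]

one_digit = round.(estimates; digits = 1)    # [-0.5, -0.2, 0.5, 0.2]
two_digits = round.(estimates; digits = 2)   # [-0.5, -0.25, 0.5, 0.25]

# An estimate just on the other side of 0.25 rounds the other way
round(-0.2501; digits = 1)                   # -0.3

# An exact tie goes to the even digit under the default RoundNearest
round(-0.25; digits = 1)                     # -0.2
```

With digits = 1, an estimate just below 0.25 becomes 0.2 and one just above becomes 0.3; digits = 2 is the smallest setting that keeps 0.25.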

sdwfrost commented 2 years ago

Thanks @AlCap23! Perhaps the issue should really be that digits does not work with SINDy: running solve(dd_prob, Ψ, STLSQ(), digits = 1) DOES give the results to the (correct) 2 significant digits.

Please feel free to use this example! I'll post a full example at http://github.com/epirecipes/sir-julia soon.
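For comparison with the SINDy side of the thread, here is a minimal sketch of sequentially thresholded least squares (STLSQ), the sparse regression SINDy is built on: fit by least squares, zero every coefficient whose magnitude is below a threshold λ, refit on the surviving terms, and repeat. The function stlsq, the threshold 0.05, the sampled grid of states and the small sin perturbation are all hypothetical choices for illustration; this is not DataDrivenDiffEq's STLSQ, and it does not model the package's digits keyword.

```julia
using LinearAlgebra

# Hypothetical STLSQ: least squares, prune |coef| < λ, refit, repeat
function stlsq(Θ, dX, λ; maxiter = 10)
    Ξ = Θ \ dX                                # ordinary least squares
    for _ in 1:maxiter
        small = abs.(Ξ) .< λ                  # terms to prune
        Ξ[small] .= 0
        for k in 1:size(dX, 2)                # refit each equation
            keep = .!small[:, k]
            if any(keep)
                Ξ[keep, k] = Θ[:, keep] \ dX[:, k]
            end
        end
    end
    return Ξ
end

# Hypothetical data: SIR states on a grid, exact derivatives from the
# issue's model (β = 0.5, γ = 0.25) plus a small deterministic perturbation
β, γ = 0.5, 0.25
grid = [(s, i) for s in range(0.2, 0.99; length = 40)
               for i in range(0.0, 0.15; length = 40) if s + i <= 1]
sv = first.(grid)
iv = last.(grid)
rv = 1 .- sv .- iv
Θ = hcat(sv, iv, rv, sv .* iv)                # library: s, i, r, s*i
dX = hcat(-β .* sv .* iv, β .* sv .* iv .- γ .* iv, γ .* iv)
dX .+= 1e-4 .* sin.((1:size(dX, 1)) .+ (0:2)')

Ξ = stlsq(Θ, dX, 0.05)
println(round.(Ξ; digits = 3))
```

In this sketch λ only decides which library terms survive; the retained coefficients are refitted rather than rounded, so -γ and γ come back close to -0.25 and 0.25.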

sdwfrost commented 2 years ago

My example is here