SciML / SteadyStateDiffEq.jl

Solvers for steady states in scientific machine learning (SciML)
Other
30 stars 22 forks source link

f does not fix t=0 #44

Closed elbert5770 closed 1 year ago

elbert5770 commented 2 years ago

The docs state: "The steady state solvers interpret the f by fixing t=0." However, the solvers do not appear to do this.

The example below works. However, if the SteadyStateProblem is defined with `SimpleKinetics!` instead of `SimpleKineticsSS` (a wrapper that overrides `t` with `0.0` before calling `SimpleKinetics!`), the solve crashes.

using DifferentialEquations, Optimization, OptimizationPolyalgorithms, OptimizationOptimJL,
      SciMLSensitivity, Zygote, Plots, BenchmarkTools,ForwardDiff,NLsolve, Statistics

"""
    SimpleKinetics!(du, u, p, t)

In-place right-hand side of a small linear kinetics model with two parallel channels
`x → y` and `x2 → y2`, driven by time-dependent sources.

# Arguments
- `du`: derivative buffer; entries 1–4 are overwritten.
- `u`: state, destructured as `(x, y, x2, y2)`.
- `p`: parameters, destructured as `(β, δ)`. `β` is the conversion rate `x → y` and
  `x2 → y2`; `δ` is the decay rate of `y` and `y2`.
- `t`: time. It enters only through the sources `1 - 0.1t` (into `x`) and `0.1t`
  (into `x2`). Because the sources are unbounded in `t`, evaluating at `t = Inf` gives
  non-finite derivatives.

Returns `nothing`, following the in-place `f!(du, u, p, t)` convention.
"""
function SimpleKinetics!(du, u, p, t)
    x, y, x2, y2 = u
    β, δ = p

    # The inflow shifts linearly from x to x2 over time; the two sources always sum to 1.
    du[1] = (1.0 - 0.1 * t) - β * x
    du[2] = -δ * y + β * x
    du[3] = 0.1 * t - β * x2
    du[4] = -δ * y2 + β * x2
    # Explicit `nothing`: otherwise the value of the last assignment leaks as the result.
    return nothing
end

"""
    SimpleKineticsSS(du, u, p, t)

Steady-state variant of `SimpleKinetics!`. It ignores the `t` supplied by the caller and
always evaluates the model at `t = 0.0`, then returns whatever `SimpleKinetics!` returns.

The solver is not guaranteed to pass `t = 0`; the maintainer reply in this thread says
the convention is `t → ∞`, where the time-dependent sources would diverge. Pinning the
time here gives a finite, well-posed root-finding problem.
"""
function SimpleKineticsSS(du, u, p, t)
    # The incoming `t` is intentionally unused.
    return SimpleKinetics!(du, u, p, 0.0)
end

"""
    loss_wrapper(prob, tsteps, probSS)

Build and return a loss closure `loss_closure(p) -> (loss, sol)` for parameter fitting.

For each parameter vector `p`, the closure:
1. solves the steady-state problem `probSS` (remade with `p`) using `SSRootfind()`;
2. integrates `prob` from that steady state with `Tsit5()`, saving at `tsteps`;
3. computes the sum of squared deviations of every saved state from `0.5`.

The closure prints the loss via `@show` on every call and returns the loss together
with the ODE solution. The optimizer callback receives that solution as `pred`.
"""
function loss_wrapper(prob, tsteps, probSS)
    function loss_closure(p)
        # NOTE(review): the steady-state solution object itself is used as `u0`;
        # `steady.u` may be the more explicit choice — confirm for the SciMLBase in use.
        steady = solve(remake(probSS; p = p), SSRootfind())
        traj = solve(remake(prob; u0 = steady, p = p), Tsit5(); saveat = tsteps)

        # Keep the variable name `loss`: `@show` prints it.
        loss = sum(abs2, traj .- 0.5)
        @show loss
        return loss, traj
    end
    return loss_closure
end

# Optimization callback: `p` is the current parameters, `l` the loss, and `pred` the ODE
# solution returned alongside the loss. It plots `pred` with y-limits (0, 6) and always
# returns `false`; under Optimization.jl's callback convention, `true` would request a halt.
callback = function (p, l, pred)
    fig = plot(pred; ylim = (0, 6))
    display(fig)
    return false
end

"""
    main()

Run the reproduction end to end:
1. simulate `SimpleKinetics!` on `[0, 10]` from a hand-picked initial state and plot it;
2. solve the steady-state problem (via `SimpleKineticsSS`, i.e. with `t` pinned to 0) and
   plot the trajectory started from that steady state;
3. fit `p` with `PolyOpt()` (ForwardDiff gradients, at most 100 iterations), minimizing
   the squared deviation of the trajectory from 0.5;
4. re-simulate from the steady state at the fitted `p`, plot it, and show the final loss.

Returns the final sum of squared deviations (the value produced by `@show`).
"""
function main()
    tspan = (0.0, 10.0)
    tsteps = 0.0:0.1:10.0
    p = [1.0, 3.0]
    u0 = [2.0, 1.0, 0.0, 0.0]

    prob = ODEProblem(SimpleKinetics!, u0, tspan, p)
    probSS = SteadyStateProblem{true}(SimpleKineticsSS, u0, p)

    # (1) Transient from the hand-picked initial condition.
    initial_run = solve(remake(prob; u0 = u0, p = p), Tsit5(); saveat = tsteps)
    display(plot(initial_run))

    # (2) Transient started from the steady state at the initial parameters.
    ss_initial = solve(probSS, SSRootfind())
    ss_run = solve(remake(prob; u0 = ss_initial, p = p), Tsit5(); saveat = tsteps)
    display(plot(ss_run))

    # (3) Fit the parameters. The second lambda argument (optimizer hyperparameters) is unused.
    loss = loss_wrapper(prob, tsteps, probSS)
    ad = Optimization.AutoForwardDiff()
    optf = Optimization.OptimizationFunction((x, _) -> loss(x), ad)
    optprob = Optimization.OptimizationProblem(optf, p)
    sol = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 100)

    # (4) Re-simulate with the fitted parameters. The names `p` and `solODE` are kept
    # because `@show` prints them.
    p = sol.u
    @show p
    ss_fitted = solve(remake(probSS; p = p), SSRootfind())
    solODE = solve(remake(prob; u0 = ss_fitted, p = p), Tsit5(); saveat = tsteps)
    display(plot(solODE))

    return @show sum(abs2, solODE .- 0.5)
end

# Script entry point: runs the whole reproduction unconditionally when this file is loaded.
main()
albheim commented 1 year ago

https://discourse.julialang.org/t/steadystateproblem-solved-result-problem/88511/10

ChrisRackauckas commented 1 year ago

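It should say t → ∞ (infinity). In other words, the docs should state that the steady-state solvers interpret f in the limit of t going to infinity, not by fixing t = 0.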