SciML / DifferentialEquations.jl

Multi-language suite for high-performance solvers of differential equations and scientific machine learning (SciML) components. Ordinary differential equations (ODEs), stochastic differential equations (SDEs), delay differential equations (DDEs), differential-algebraic equations (DAEs), and more in Julia.
https://docs.sciml.ai/DiffEqDocs/stable/

Simultaneous parameter and initial conditions estimation for ODEs #717

Closed ClaudMor closed 3 years ago

ClaudMor commented 3 years ago

Hello,

I was wondering whether it is possible to simultaneously optimize an ODE's parameters and initial conditions w.r.t. a given dataset.

The approach I have followed so far is:

  1. randomly sample the initial conditions
  2. calibrate the model's parameters with those initial conditions (e.g. with Optim), and save the triple (initial conditions, calibrated_parameters, final loss).
  3. repeat steps 1. and 2. as many times as you wish
  4. select the triple that corresponds to the minimum loss.
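
The four steps above can be sketched roughly as follows. This is only an illustration on a toy one-state decay model, not the SIRD model from this issue: the names `decay!`, `trials` and `best`, the sampling range for the initial condition, and the bounds on the rate `k` are all assumptions made for the example. The inner calibration uses Optim's bounded univariate `optimize(f, lower, upper)`.

```julia
using OrdinaryDiffEq, Optim, Random

Random.seed!(1)  # reproducible sampling of initial conditions

# Toy model (assumption for illustration): u' = -k * u
decay!(du, u, p, t) = (du[1] = -p[1] * u[1]; nothing)

tsteps = 0.0:0.5:5.0
prob = ODEProblem(decay!, [2.0], (0.0, 5.0), [0.5])
data = Array(solve(prob, Tsit5(), saveat = tsteps))  # synthetic "dataset"

# Sum of squared errors for a given initial condition u0 and rate k
function loss(u0, k)
    sol = solve(remake(prob, u0 = [u0], p = [k]), Tsit5(), saveat = tsteps)
    return sum(abs2, Array(sol) .- data)
end

# Steps 1-3: sample u0, calibrate k for that u0, store the triple
trials = map(1:20) do _
    u0 = 1.0 + 2.0 * rand()                      # step 1: u0 sampled in [1, 3]
    res = optimize(k -> loss(u0, k), 0.0, 2.0)   # step 2: bounded 1-D search over k
    (u0 = u0, k = Optim.minimizer(res), loss = Optim.minimum(res))
end

# Step 4: keep the triple with the smallest final loss
best = trials[argmin([t.loss for t in trials])]
```

Note that the initial condition is only ever explored by random sampling here, so the quality of `best` depends on how many samples are drawn.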

I know I could probably use Turing.jl, but I can't make it scale well with the number of parameters. I also read that the prob_generator function switches the roles of initial conditions and model parameters, but I didn't really understand whether it could help in this mixed situation.

So I'd like to ask what the ways (if any) are to calibrate initial conditions and parameter values simultaneously.

If you'd like to give an explicit example, I'll post below a simple model with hard coded calibration data:

using ModelingToolkit, DifferentialEquations

# Model parameters

β = 0.01 # infection rate
λ_R = 0.05 # inverse of transition time from  infected to recovered
λ_D = 0.83 # inverse of transition time from  infected to dead

𝒫 = [β, λ_R, λ_D]

# regional contact matrix and regional population

## regional contact matrix
const C = [3.45536  0.485314  0.506389  0.123002;
           0.597721  2.11738   0.911374  0.323385;
           0.906231  1.35041   1.60756   0.67411;
           0.237902  0.432631  0.726488  0.979258] # 4x4 contact matrix

## regional population stratified by age
const N = [723208, 874150, 1330993, 1411928] # array of 4 elements, each representing the absolute population of the corresponding age class.

# Initial conditions 
i₀ = 0.075 # fraction of initial infected people in every age class
I₀ = fill(i₀, 4)
S₀ = N .- I₀
R₀ = zeros(length(N))
D₀ = zeros(length(N))
D_tot₀ = 0.0
ℬ = vcat(S₀, I₀, R₀, D₀, D_tot₀)

# Time 
final_time = 20
𝒯 = (1.0, final_time)

# initialize this parameter (death probability stratified by age, taken from literature)
const δ = [0.003/100, 0.004/100, (0.015 + 0.030 + 0.064 + 0.213 + 0.718) / (5 * 100), (2.384 + 8.466 + 12.497 + 1.117) / (4 * 100)]

function SIRD!(du, u, p, t)
    # extract parameters
    β, λ_R, λ_D = p

    # State variables
    S = @view u[1:4]
    I = @view u[5:8]
    R = @view u[9:12]
    D = @view u[13:16]
    D_tot = u[17]

    # Differentials
    dS = @view du[1:4]
    dI = @view du[5:8]
    dR = @view du[9:12]
    dD = @view du[13:16]

    # Force of infection
    Λ = β * [sum(C[i, j] * I[j] / N[j] for j in 1:size(C, 2)) for i in 1:size(C, 1)]

    # System of equations
    @. dS = -Λ * S
    @. dI = Λ * S - ((1-δ) * λ_R + δ * λ_D) * I
    @. dR = λ_R * (1-δ) * I
    @. dD = λ_D * δ * I
    du[end] = sum(dD)
end

# create problem and check it works
problem  = ODEProblem(SIRD!, ℬ, 𝒯, 𝒫)
solution = solve(problem, Tsit5(), saveat = 1:final_time)

# modelingtoolkitize it
sys = modelingtoolkitize(problem)
fast_problem = ODEProblem(sys, ℬ, 𝒯, 𝒫)
fast_solution = solve(fast_problem, Tsit5(), saveat = 1:final_time)

# death calibration data (this time series should be reproduced by the D_tot variable)
const real_deaths = [0,2,4,5,5,13,17,21,26,46,59,81,111,133,154,175,209,238,283,315]

Alternatively, if you'd like to use a more complex model already integrated with Optim, Turing's NUTS and ADVI, you may take a look at the MWE from here.

Thanks in advance

ChrisRackauckas commented 3 years ago

This is an example doing 1 initial condition and all parameters: https://diffeqflux.sciml.ai/dev/examples/feedback_control/

ClaudMor commented 3 years ago

Thanks for the reference.

Unfortunately, I'm not yet comfortable enough with neural ODEs, so I was planning to use simpler tools like DiffEqParamEstim (or others?). From what I saw in the reference, I thought I could write a custom function like predict_univ, pass it to build_loss_objective, and then use Optim, as described here. But build_loss_objective takes a ::DEProblem, so I'm not sure it will work this way. Does a working variation of something like this exist? Or would you have any other suggestion?
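
One possible shape for such a variation, sketched under assumptions: a generator function that takes the problem and a single vector holding the initial conditions followed by the parameters, and rebuilds the problem with `remake`. The name `split_generator` and the 2 + 4 layout of `θ` are hypothetical and chosen for a Lotka-Volterra toy model. Whether it can be handed to `build_loss_objective` through the `prob_generator` keyword mentioned above should be checked against the installed DiffEqParamEstim version.

```julia
using OrdinaryDiffEq

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
end

prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Hypothetical generator: the first 2 entries of θ are the initial conditions,
# the remaining 4 are the model parameters.
split_generator(prob, θ) = remake(prob, u0 = θ[1:2], p = θ[3:end])

θ = [0.9, 1.1, 1.4, 1.1, 2.9, 0.9]
new_prob = split_generator(prob, θ)

# Assumption: a function of this (prob, θ) form is what the `prob_generator`
# keyword of build_loss_objective expects; verify before relying on it.
sol = solve(new_prob, Tsit5(), saveat = 0.0:0.1:10.0)
```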

ChrisRackauckas commented 3 years ago

You're trying to make a parameter estimation system that doesn't do this do this, instead of using a parameter estimation system that does do this to do this.

using DifferentialEquations, Flux, Optim, DiffEqFlux, DiffEqSensitivity, Plots

function lotka_volterra!(du, u, p, t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end

# Initial condition
u0 = [1.0, 1.0]

# Simulation interval and intermediary points
tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0

# LV equation parameter. p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0]
theta = [u0;p]

# Setup the ODE problem, then solve
prob = ODEProblem(lotka_volterra!, u0, tspan, p)
sol = solve(prob, Tsit5())

# Plot the solution
using Plots
plot(sol)
savefig("LV_ode.png")

function loss(theta)
  _prob = remake(prob,u0=theta[1:2],p=[3:end])
  sol = solve(_prob, Tsit5(), saveat = tsteps)
  loss = sum(abs2, sol.-1)
  return loss, sol
end

callback = function (p, l, pred)
  display(l)
  plt = plot(pred, ylim = (0, 6))
  display(plt)
  # Tell sciml_train to not halt the optimization. If return true, then
  # optimization stops.
  return false
end

result_ode = DiffEqFlux.sciml_train(loss, theta,
                                    ADAM(0.1),
                                    cb = callback,
                                    maxiters = 100)

This is a relatively straightforward example. I would highly suggest just using the right tool.

ClaudMor commented 3 years ago

Thank you very much,

I just noticed that the line:

_prob = remake(prob,u0=theta[1:2],p=[3:end])

should maybe be changed to:

_prob = remake(prob,u0=theta[1:2],p=theta[3:end])
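
For reference, a minimal sketch of the same joint estimate with the `theta[3:end]` indexing, run through Optim's derivative-free `NelderMead` instead of `sciml_train`. The synthetic `data`, the starting guess `theta0`, the iteration limit, and the guard against failed solves are assumptions added for illustration, not part of the example above.

```julia
using OrdinaryDiffEq, Optim

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
end

tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0
u0_true = [1.0, 1.0]
p_true = [1.5, 1.0, 3.0, 1.0]

prob = ODEProblem(lotka_volterra!, u0_true, tspan, p_true)
data = Array(solve(prob, Tsit5(), saveat = tsteps))  # synthetic target data

# theta = [u0; p]: the first 2 entries are initial conditions, the rest parameters
function loss(theta)
    _prob = remake(prob, u0 = theta[1:2], p = theta[3:end])
    pred = Array(solve(_prob, Tsit5(), saveat = tsteps))
    # If the solver stopped early (e.g. a diverging trajectory), reject this theta
    size(pred) == size(data) || return Inf
    return sum(abs2, pred .- data)
end

theta0 = [0.8, 1.2, 1.2, 0.8, 2.5, 0.8]  # assumed starting guess
result = optimize(loss, theta0, NelderMead(), Optim.Options(iterations = 2000))
theta_hat = Optim.minimizer(result)      # estimated [u0; p]
```

Because initial conditions and parameters live in one vector, the optimizer updates both at every step, which avoids the outer random sampling loop over initial conditions.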