ChrisRackauckas / universal_differential_equations

Repository for the Universal Differential Equations for Scientific Machine Learning paper, describing a computational basis for high performance SciML
https://arxiv.org/abs/2001.04385
MIT License
220 stars 59 forks source link

SEIR Example #44

Open ghost opened 2 years ago

ghost commented 2 years ago

Translation of the SEIR example, based on the first Lotka–Volterra example:

Hiya, ok, here's the first...

# Run relative to this script's directory, using (and installing) its local project
# environment.
cd(@__DIR__)
using Pkg
Pkg.activate(".")
Pkg.instantiate()

# Single experiment, move to ensemble further on
# Some good parameter values are stored as comments right now
# because this is really good practice
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra 
using SciMLSensitivity
using Random
using Optimization, OptimizationFlux, OptimizationOptimJL #OptimizationFlux for ADAM and OptimizationOptimJL for BFGS
using Lux
using Statistics
using Plots
gr()
#using DiffEqSensitivity, Optim  (older setup; SciMLSensitivity and OptimizationOptimJL are used above)
#using DiffEqFlux, Flux

"""
    corona!(du, u, p, t)

In-place right-hand side of the SEIR-type ground-truth model.

State `u = [S, E, I, R, N, D, C]`; parameters `p = [F, β0, α, κ, μ, σ, γ, d, λ]`.
The seven time derivatives are written into `du`. The time-varying rate that
multiplies `S*I/N` is obtained from `β(t, β0, D, N, κ, α)`.
"""
function corona!(du,u,p,t)
    S, E, I, R, N, D, C = u
    F, β0, α, κ, μ, σ, γ, d, λ = p
    exposure = β0*S*F/N                  # F-driven infection term
    contact = β(t,β0,D,N,κ,α)*S*I/N      # β-weighted S*I/N infection term
    du[1] = -exposure - contact - μ*S    # susceptible
    du[2] = exposure + contact - (σ+μ)*E # exposed
    du[3] = σ*E - (γ+μ)*I                # infected
    du[4] = γ*I - μ*R                    # removed (recovered + dead)
    du[5] = -μ*N                         # total population
    du[6] = d*γ*I - λ*D                  # severe, critical cases, and deaths
    du[7] = σ*E                          # cumulative cases
end

# Time-varying transmission rate: the baseline β0 scaled by (1 - α) and by
# (1 - D/N)^κ, which shrinks as D grows relative to N. `t` is accepted but unused.
function β(t, β0, D, N, κ, α)
    scaled = β0 * (1 - α)
    return scaled * (1 - D / N)^κ
end
# Ground-truth setup. State order used throughout: u = [S, E, I, R, N, D, C].
S0 = 14e6
u0 = [0.9*S0, 0.0, 0.0, 0.0, S0, 0.0, 0.0]
# Parameter order as destructured in corona!: [F, β0, α, κ, μ, σ, γ, d, λ]
p_ = [10.0, 0.5944, 0.4239, 1117.3, 0.02, 1/3, 1/5,0.2, 1/11.2]
# R0 = β0/γ * σ/(σ+μ) from the entries above; not referenced elsewhere in this snippet
R0 = p_[2]/p_[7]*p_[6]/(p_[6]+p_[5])
# Training window
tspan = (0.0, 21.0)
# Reference solution saved at every integer time t = 0, 1, ..., 21
prob = ODEProblem(corona!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
t = solution.t
# Rows [2:4] are Exposed, Infected, Removed; after the transpose X is (time × 3)
X = Array(solution[2:4,:])'
plot(X)

# Reference solution over a longer timespan; the extrapolation plots below compare
# the trained models against it. Note: this rebinds `prob` to the tspan2 problem.
tspan2 = (0.0,60.0)
prob = ODEProblem(corona!, u0, tspan2, p_)
solution_extrapolate = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
# Exposed, Infected, Removed as a (time × 3) matrix
extrapolate = Array(solution_extrapolate[2:4,:])'
plot(extrapolate)

# Ideal (noise-free) data: one row per state, one column per saved time
tsdata = Array(solution)

# Add zero-mean Gaussian noise with standard deviation ≈1e-5 to every entry.
# NOTE(review): the noise is drawn from the global RNG before the Random.seed!(111)
# call further down, so it is not reproducible between runs. It is also tiny relative
# to S and N (S0 = 14e6).
noisy_data = tsdata + Float32(1e-5)*randn(eltype(tsdata), size(tsdata))
# Plot the absolute noise added to each state
plot(abs.(tsdata-noisy_data)')

### Neural ODE
#Predicts for unknown equations
# `rng` is the default RNG, so the Random.seed!(111) call below also fixes it.
rng = Random.default_rng()
Random.seed!(111)

#7 inputs for 7 equations, 5 outputs because we know 2 equations already
# (coronaNODE keeps dN and dC in closed form; the network supplies the other five).
# NOTE(review): the network is fed the raw state, whose S and N entries are ~1e7, so
# the first tanh layer is likely saturated. Consider rescaling the inputs.
U = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 5))

# Get the initial parameters and state variables of the model.
# These globals (U, p, st) are rebound again in the Universal ODE section.
p, st = Lux.setup(rng, U)

"""
    coronaNODE(du, u, p, t, p_)

In-place right-hand side of the neural ODE.

The network `U` (a global, evaluated with the global Lux state `st`) predicts
dS, dE, dI, dR and dD from the raw state `u = [S, E, I, R, N, D, C]`. dN and dC
keep their known forms.

`p` holds the trainable network parameters. `p_` is the full known parameter vector
`[F, β0, α, κ, μ, σ, γ, d, λ]`, of which only μ (5th) and σ (6th) are used here.
"""
function coronaNODE(du,u,p,t,p_)
    û = U(u, p, st)[1] # Network prediction: a Lux model call returns (output, state)
    S,E,I,R,N,D,C = u
    # μ and σ are entries 5 and 6 of p_. Destructuring the first two entries
    # (`μ,σ = p_`) silently picked up F and β0 instead.
    μ, σ = p_[5], p_[6]
    dS = û[1]
    dE = û[2]
    dI = û[3]
    dR = û[4]
    dN = -μ*N # total population
    dD = û[5]
    dC = σ*E # +cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end

# Bind the known parameters p_ so the solver sees the standard (du, u, p, t) signature.
function NODE_dynamics!(du, u, p, t)
    return coronaNODE(du, u, p, t, p_)
end
# Neural ODE problem over the training window; p holds the initial network parameters.
prob_node = ODEProblem(NODE_dynamics!, u0, tspan, p)

## Function to train the network
"""
    predict(θ, X = noisy_data[:,1], T = t)

Solve the neural ODE `prob_node` with network parameters `θ` from initial state `X`,
saving at times `T`. The solution is returned as a matrix with one row per state and
one column per saved time.
"""
function predict(θ, X = noisy_data[:,1], T = t)
    sens = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
    sol = solve(prob_node, Vern7();
                u0 = X, p = θ, saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = sens)
    return Array(sol)
end

# Simple L2 loss: sum of squared differences between the noisy data and the
# prediction, over all seven states.
function loss(θ)
    residual = noisy_data .- predict(θ)
    return sum(abs2, residual)
end

# Container to track the losses (one entry per callback invocation)
losses = Float32[]

# Optimization.jl callback: record the loss, report every 50 calls, and return false
# so the optimizer keeps going.
callback = function (p, l)
    push!(losses, l)
    iszero(length(losses) % 50) &&
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    return false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
# The objective ignores Optimization.jl's second (problem-parameter) argument.
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1 = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS, warm-started from the ADAM result.
# NOTE(review): newer Optimization.jl versions expose the solution as `res.u`;
# `.minimizer` may be deprecated depending on the installed version.
optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses: the first 200 entries are treated as ADAM, the rest as BFGS.
# NOTE(review): this assumes the callback ran exactly 200 times during ADAM. Confirm
# for the installed Optimization.jl version, otherwise the split point is off.
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "pl_lossesNODE.png")

# Rename the best candidate (final BFGS parameters)
p_trained = res2.minimizer

## Analysis of the trained network
# Plot the data and the approximation

# Make the prediction to match solution.t
X̂ = predict(p_trained, noisy_data[:,1], t)
# NODE prediction vs the noisy measurements it was trained on (rows 2:4 = E, I, R)
pl_trajectory = plot(t, transpose(X̂[2:4,:]), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["NODE Approximation" nothing])
scatter!(solution.t, transpose(noisy_data[2:4,:]), color = :black, label = ["Measurements" nothing])
savefig(pl_trajectory, "plots_trajectory_reconstructionNODE.png")

#Extrapolate the solution to match tspan2
# `predict` solves `prob_node`, whose tspan ends at the training horizon (t = 21), and
# the solver only saves points inside tspan. A copy of the problem over `tspan2` is
# needed so the prediction lines up with `solution_extrapolate.t`.
prob_node_extrapolate = remake(prob_node; u0 = noisy_data[:,1], p = p_trained, tspan = tspan2)
ExtrapolateX̂ = Array(solve(prob_node_extrapolate, Vern7(), abstol=1e-6, reltol=1e-6,
                           saveat = solution_extrapolate.t))
extrapolate_trajectory = plot(solution_extrapolate.t, transpose(ExtrapolateX̂[2:4,:]), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["NODE Approximation" nothing])
scatter!(solution_extrapolate.t, transpose(solution_extrapolate[2:4,:]), color = :black, label = ["Measurements" nothing])
# Save the extrapolation figure (this previously saved `pl_trajectory` by mistake)
savefig(extrapolate_trajectory, "ExtrapolateNODE.png")

### Universal ODE
##Prediction for missing parameters
# `rng` is the default RNG, so the Random.seed!(222) call below also fixes it.
rng = Random.default_rng()
Random.seed!(222)

#7 inputs for 7 equations, 1 output for 1 missing part of the equation
# Note: this rebinds the globals U, p and st that coronaNODE reads, so the neural ODE
# functions above cannot be reused after this point.
U = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 1))

# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

"""
    coronaUDE(du, u, p, t, p_true)

In-place right-hand side of the universal ODE.

This is the same model as `corona!`, except that the β(t,…)*S*I/N term is replaced
by the single output of the network `U` (a global, evaluated with the global Lux
state `st`) on the raw state.

`p` holds the trainable network parameters. `p_true` is the known parameter vector
`[F, β0, α, κ, μ, σ, γ, d, λ]`; α and κ are not used here.
"""
function coronaUDE(du,u,p,t,p_true)
    û = U(u, p, st)[1] # Network prediction: a Lux model call returns (output, state)
    S,E,I,R,N,D,C = u
    # Use the parameter vector passed in, not the global `p_`, so the function
    # honours whatever known parameters the caller supplies.
    F,β0,α,κ,μ,σ,γ,d,λ = p_true
    dS = -β0*S*F/N - û[1] -μ*S # susceptible
    dE = β0*S*F/N + û[1] -(σ+μ)*E # exposed
    dI = σ*E - (γ+μ)*I # infected
    dR = γ*I - μ*R # removed (recovered + dead)
    dN = -μ*N # total population
    dD = d*γ*I - λ*D # severe, critical cases, and deaths
    dC = σ*E # +cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end

# Bind the known parameters p_ so the solver sees the standard (du, u, p, t) signature.
function UDE_dynamics!(du, u, p, t)
    return coronaUDE(du, u, p, t, p_)
end
# UDE problem over the training window; p holds the initial network parameters.
prob_ude = ODEProblem(UDE_dynamics!, u0, tspan, p)

## Function to train the network
"""
    predict(θ, X = noisy_data[:,1], T = t)

Solve the universal ODE `prob_ude` with network parameters `θ` from initial state `X`,
saving at times `T`. The solution is returned as a matrix with one row per state and
one column per saved time. (This replaces the neural ODE `predict` method above.)
"""
function predict(θ, X = noisy_data[:,1], T = t)
    sens = InterpolatingAdjoint(autojacvec=ReverseDiffVJP())
    sol = solve(prob_ude, Vern7();
                u0 = X, p = θ, saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = sens)
    return Array(sol)
end

# Simple L2 loss: sum of squared differences between the noisy data and the
# prediction, over all seven states.
function loss(θ)
    residual = noisy_data .- predict(θ)
    return sum(abs2, residual)
end

# Fresh container for the UDE losses (one entry per callback invocation)
losses = Float32[]

# Optimization.jl callback: record the loss, report every 50 calls, and return false
# so the optimizer keeps going.
callback = function (p, l)
    push!(losses, l)
    iszero(length(losses) % 50) &&
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    return false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
# The objective ignores Optimization.jl's second (problem-parameter) argument.
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1UDE = Optimization.solve(optprob, ADAM(0.01), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS, warm-started from the UDE's own ADAM result. This previously used
# `res1`, the neural ODE's ADAM result, whose parameters belong to a different network.
optprob2 = Optimization.OptimizationProblem(optf, res1UDE.minimizer)
res2UDE = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses: the first 200 entries are treated as ADAM, the rest as BFGS.
# NOTE(review): this assumes the callback ran exactly 200 times during ADAM. Confirm
# for the installed Optimization.jl version, otherwise the split point is off.
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "plot_lossesUDE.png")
# Rename the best candidate (final BFGS parameters)
p_trained = res2UDE.minimizer

## Analysis of the trained network
# Plot the data and the approximation
X̂ = predict(p_trained, noisy_data[:,1], t)
# UDE prediction (trained on noisy data) vs the noise-free reference solution.
# Note: the scatter series is labelled "Measurements" but plots `solution`, not `noisy_data`.
pl_trajectory = plot(t, transpose(X̂[2:4,:]), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(solution[2:4,:]), color = :black, label = ["Measurements" nothing])
savefig(pl_trajectory, "plot_trajectory_reconstructionUDE.png")

# Extrapolate out
# `predict` solves `prob_ude`, whose tspan ends at the training horizon (t = 21), and
# the solver only saves points inside tspan. A copy of the problem over `tspan2` is
# needed so the prediction lines up with `solution_extrapolate.t`.
prob_ude_extrapolate = remake(prob_ude; u0 = noisy_data[:,1], p = p_trained, tspan = tspan2)
ExtrapolateX̂ = Array(solve(prob_ude_extrapolate, Vern7(), abstol=1e-6, reltol=1e-6,
                           saveat = solution_extrapolate.t))
extrapolate_trajectory = plot(solution_extrapolate.t, transpose(ExtrapolateX̂[2:4,:]), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution_extrapolate.t, transpose(solution_extrapolate[2:4,:]), color = :black, label = ["Measurements" nothing])
savefig(extrapolate_trajectory, "ExtrapolateUDE.png")
ghost commented 2 years ago

P.S. I just realised it was unclear which model is the universal ODE and which is the neural ODE, so I've edited the code comments.

ChrisRackauckas commented 2 years ago

@rajdandekar

RajDandekar commented 2 years ago

@ccrnn I also have been working on translating the UDE codes into the SciML Sensitivity + Lux interface. Here are the key points based on your prior comment:

(a) The code I have provided below mimics the original code closely.

(b) The plots for both prediction and estimation match the original plots in the paper.

(c) I have not yet done the SINDY part, but will implement it in the coming days.

Can you have a look at the code below and also compare with yours? This may also improve some of the results you are seeing on your end I guess:

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using Lux,Optimization, OptimizationOptimJL, DiffEqFlux, Flux
using Plots

using Random
# Default RNG, used for Lux parameter initialisation below.
# NOTE(review): nothing in this script seeds it, so the noisy data and the initial
# network weights differ between runs. Consider Random.seed!(rng, <seed>).
rng = Random.default_rng()

"""
    corona!(du, u, p, t)

In-place right-hand side of the SEIR-type ground-truth model.

State `u = [S, E, I, R, N, D, C]`; parameters `p = [F, β0, α, κ, μ, σ, γ, d, λ]`.
The seven time derivatives are written into `du`. The time-varying rate that
multiplies `S*I/N` is obtained from `β(t, β0, D, N, κ, α)`.
"""
function corona!(du,u,p,t)
    S, E, I, R, N, D, C = u
    F, β0, α, κ, μ, σ, γ, d, λ = p
    exposure = β0*S*F/N                  # F-driven infection term
    contact = β(t,β0,D,N,κ,α)*S*I/N      # β-weighted S*I/N infection term
    du[1] = -exposure - contact - μ*S    # susceptible
    du[2] = exposure + contact - (σ+μ)*E # exposed
    du[3] = σ*E - (γ+μ)*I                # infected
    du[4] = γ*I - μ*R                    # removed (recovered + dead)
    du[5] = -μ*N                         # total population
    du[6] = d*γ*I - λ*D                  # severe, critical cases, and deaths
    du[7] = σ*E                          # cumulative cases
end
# Time-varying transmission rate: the baseline β0 scaled by (1 - α) and by
# (1 - D/N)^κ, which shrinks as D grows relative to N. `t` is accepted but unused.
function β(t, β0, D, N, κ, α)
    scaled = β0 * (1 - α)
    return scaled * (1 - D / N)^κ
end
# Ground-truth setup. State order: u = [S, E, I, R, N, D, C];
# parameter order as destructured in corona!: [F, β0, α, κ, μ, σ, γ, d, λ].
S0 = 14e6
u0 = [0.9*S0, 0.0, 0.0, 0.0, S0, 0.0, 0.0]
p_ = [10.0, 0.5944, 0.4239, 1117.3, 0.02, 1/3, 1/5,0.2, 1/11.2]
# R0 = β0/γ * σ/(σ+μ); not referenced elsewhere in this snippet
R0 = p_[2]/p_[7]*p_[6]/(p_[6]+p_[5])
tspan = (0.0, 21.0)
prob = ODEProblem(corona!, u0, tspan, p_)
# Reference solution saved at t = 0, 1, ..., 21
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)

# Longer reference solution for judging extrapolation (rebinds `prob`)
tspan2 = (0.0,60.0)
prob = ODEProblem(corona!, u0, tspan2, p_)
solution_extrapolate = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)

# Ideal data
tsdata = Array(solution)
# Add zero-mean Gaussian noise with standard deviation ≈1e-5 to every entry
noisy_data = tsdata + Float32(1e-5)*randn(eltype(tsdata), size(tsdata))

# Plot the absolute noise added to each state
plot(abs.(tsdata-noisy_data)')

### Neural ODE

# 7 inputs and 7 outputs; dudt_node only uses the first 5 outputs (dS, dE, dI, dR, dD).
ann_node = Lux.Chain(Lux.Dense(7, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 7))
p1, st1 = Lux.setup(rng, ann_node)
# Flatten the parameter NamedTuple into a ComponentArray for the optimizers
p = Lux.ComponentArray(p1)

"""
    dudt_node(du, u, p, t)

In-place right-hand side of the neural ODE.

`ann_node` (a global, evaluated with the global Lux state `st1`) predicts dS, dE, dI,
dR and dD from the features `[S/N, E, I, R, N, D/N, C]`. dN and dC keep their known
forms. `p` holds the network parameters; the known parameters are read from the
global `p_`, and only μ and σ are used.

Derivatives are written in state order `u = [S, E, I, R, N, D, C]`. The same vector
is also returned.
"""
function dudt_node(du, u,p,t)
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p_
    # Evaluate the network once per call (previously it ran five times);
    # a Lux model call returns (output, state).
    nn = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1]
    dS, dE, dI, dR, dD = nn[1], nn[2], nn[3], nn[4], nn[5]
    dN = -μ*N # total population
    dC = σ*E # +cumulative cases

    # Slot 5 is N and slot 6 is D; these two were previously swapped, so N received
    # the network's dD and D received -μ*N.
    du[1] = dS
    du[2] = dE
    du[3] = dI
    du[4] = dR
    du[5] = dN
    du[6] = dD
    du[7] = dC

    [dS,dE,dI,dR,dN,dD,dC]
end

# In-place neural ODE problem over the training window; the network parameters are
# supplied at solve time.
prob_node = ODEProblem{true}(dudt_node, u0, tspan)

# Solve the neural ODE with network parameters θ, saving at t = 0, 1, ..., 21, and
# return the solution as a matrix (one row per state).
function predict(θ)
    sens = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))
    sol = solve(prob_node, Tsit5(); p = θ, saveat = 1,
                abstol = 1e-6, reltol = 1e-6, sensealg = sens)
    return Array(sol)
end

# Sum-of-squares loss on the fitted compartments only (rows 2:4 = E, I, R).
# No regularisation right now; an L1 penalty on the network weights is a possible
# addition.
function loss(θ)
    observed = 2:4
    pred = predict(θ)
    return sum(abs2, noisy_data[observed, :] .- pred[observed, :])
end

loss(p)

# Progress callback: count calls in the global `iter` and print the loss on every 10th
# call. Returning false tells the optimizer to continue.
iter = 0
function callback(θ,l)
    global iter
    iter += 1
    iter % 10 == 0 && println(l)
    return false
end

# Two-stage training: ADAM (learning rate 1e-4, 1500 iterations), then BFGS
# warm-started from the ADAM result. The objective ignores Optimization.jl's second
# (problem-parameter) argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1 = Optimization.solve(optprob, ADAM(0.0001); callback = callback, maxiters = 1500)

optprob2 = remake(optprob; u0 = res1.u)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm = 0.01);
                          callback = callback, maxiters = 10000)

data_pred = predict(res2.u)

# Plot the estimates against solution.t. Without explicit x values, Plots.jl uses
# 1:n, which shifts every estimate one step right of the data saved at t = 0, 1, ..., 21.
scatter(solution, vars=[2,3,4], label=["True Exposed" "True Infected" "True Recovered"])
plot!(solution.t, data_pred[2,:], label=["Estimated Exposed"])
plot!(solution.t, data_pred[3,:], label=["Estimated Infected" ])
plot!(solution.t, data_pred[4,:], label=["Estimated Recovered"])

# Plot the losses
# TO DO: record the losses in the callback, then
#        plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")

# Extrapolate out
prob_node_extrapolate = ODEProblem{true}(dudt_node, u0, tspan2)
_sol_node = Array(solve(prob_node_extrapolate, Tsit5(),p = res2.u, saveat = 1,abstol=1e-12, reltol=1e-12,
                     sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))

# Both series are saved at t = 0, 1, ..., 60, so plot the estimates against
# solution_extrapolate.t. The default x of 1:n would shift them one step right of
# the data and of the training-end marker at t = 21.
p_node = scatter(solution_extrapolate, vars=[2,3,4], legend = :topleft, label=["True Exposed" "True Infected" "True Recovered"], title="Neural ODE Extrapolation")
plot!(p_node, solution_extrapolate.t, _sol_node[2,:], lw = 5, label=["Estimated Exposed"])
plot!(p_node, solution_extrapolate.t, _sol_node[3,:], lw = 5, label=["Estimated Infected" ])
plot!(p_node, solution_extrapolate.t, _sol_node[4,:], lw = 5, label=["Estimated Recovered"])
plot!(p_node,[20.99,21.01],[0.0,maximum(hcat(Array(solution_extrapolate[2:4,:]),Array(_sol_node[2:4,:])))],lw=5,color=:black,label="Training Data End")

# Save this figure explicitly rather than whatever plot is current
savefig(p_node, "neuralode_extrapolation.png")
savefig(p_node, "neuralode_extrapolation.pdf")

### Universal ODE Part 1

# 3 inputs (S/N, I, D/N) -> 1 output; replaces the β-weighted S*I/N term in dudt_.
ann = Lux.Chain(Lux.Dense(3, 64, tanh),Lux.Dense(64, 64, tanh), Lux.Dense(64, 1))
# Note: rebinds the globals p1, st1 and p used by the neural ODE section above.
p1, st1 = Lux.setup(rng, ann)
p = Lux.ComponentArray(p1)

"""
    dudt_(du, u, p, t)

In-place right-hand side of the universal ODE.

This is the known SEIR-type model with the β-weighted S*I/N infection term replaced
by the scalar output of `ann` (a global, evaluated with the global Lux state `st1`)
on the features `[S/N, I, D/N]`. `p` holds the network parameters; the known
parameters are read from the global `p_`.
"""
function dudt_(du, u,p,t)
    S, E, I, R, N, D, C = u
    F, β0, α, κ, μ, σ, γ, d, λ = p_
    features = [S/N, I, D/N]
    z = ann(features, p, st1)[1][1]      # scalar network output
    exposure = β0*S*F/N                  # F-driven infection term
    du[1] = -exposure - z - μ*S          # susceptible
    du[2] = exposure + z - (σ+μ)*E       # exposed
    du[3] = σ*E - (γ+μ)*I                # infected
    du[4] = γ*I - μ*R                    # removed (recovered + dead)
    du[5] = -μ*N                         # total population
    du[6] = d*γ*I - λ*D                  # severe, critical cases, and deaths
    du[7] = σ*E                          # cumulative cases
end

# In-place UDE problem over the training window; the network parameters are supplied
# at solve time.
prob_nn = ODEProblem{true}(dudt_,u0, tspan)

# Solve the UDE with network parameters θ, saving at the reference times solution.t,
# and return the solution as a matrix (one row per state).
function predict(θ)
    sens = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))
    sol = solve(prob_nn, Tsit5(); p = θ, saveat = solution.t,
                abstol = 1e-6, reltol = 1e-6, sensealg = sens)
    return Array(sol)
end

# Sum-of-squares loss on the fitted compartments only (rows 2:4 = E, I, R).
# No regularisation right now; an L1 penalty on the network weights is a possible
# addition.
function loss(θ)
    observed = 2:4
    pred = predict(θ)
    return sum(abs2, noisy_data[observed, :] .- pred[observed, :])
end

loss(p)

# Progress callback: count calls in the global `iter` and print the loss on every 50th
# call. Returning false tells the optimizer to continue.
iter = 0
function callback(θ,l)
    global iter
    iter += 1
    iter % 50 == 0 && println(l)
    return false
end

# Two-stage training: ADAM (learning rate 0.01, 500 iterations), then BFGS
# (up to 550 iterations) warm-started from the ADAM result. The objective ignores
# Optimization.jl's second (problem-parameter) argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1 = Optimization.solve(optprob, ADAM(0.01); callback = callback, maxiters = 500)

optprob2 = remake(optprob; u0 = res1.u)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm = 0.01);
                          callback = callback, maxiters = 550)

uode_sol = predict(res2.u)

# Plot the estimates against solution.t. Without explicit x values, Plots.jl uses
# 1:n, which shifts every estimate one step right of the data saved at t = 0, 1, ..., 21.
scatter(solution, vars=[2,3,4], label=["True Exposed" "True Infected" "True Recovered"])
plot!(solution.t, uode_sol[2,:], label=["Estimated Exposed"])
plot!(solution.t, uode_sol[3,:], label=["Estimated Infected" ])
plot!(solution.t, uode_sol[4,:], label=["Estimated Recovered"])

# Plot the losses
# TO DO: record the losses in the callback, then
#        plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")

# Collect the state trajectory and the derivatives
# (kept commented out; presumably for a later equation-discovery step)
#X = noisy_data
# Ideal derivatives
#DX = Array(solution(solution.t, Val{1}))

# Extrapolate out
prob_nn2 = ODEProblem{true}(dudt_, u0, tspan2)
_sol_uode = Array(solve(prob_nn2, Tsit5(),p = res2.u, saveat = 1,abstol=1e-12, reltol=1e-12,
                     sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))

# Both series are saved at t = 0, 1, ..., 60, so plot the estimates against
# solution_extrapolate.t. The default x of 1:n would shift them one step right of
# the data and of the training-end marker at t = 21.
# The title previously said "Neural ODE Extrapolation" (copy-paste); this is the UDE.
p_uode = scatter(solution_extrapolate, vars=[2,3,4], legend = :topleft, label=["True Exposed" "True Infected" "True Recovered"], title="Universal ODE Extrapolation")
plot!(p_uode, solution_extrapolate.t, _sol_uode[2,:], lw = 5, label=["Estimated Exposed"])
plot!(p_uode, solution_extrapolate.t, _sol_uode[3,:], lw = 5, label=["Estimated Infected" ])
plot!(p_uode, solution_extrapolate.t, _sol_uode[4,:], lw = 5, label=["Estimated Recovered"])
plot!(p_uode,[20.99,21.01],[0.0,maximum(hcat(Array(solution_extrapolate[2:4,:]),Array(_sol_uode[2:4,:])))],lw=5,color=:black,label="Training Data End")

# Save this figure explicitly rather than whatever plot is current
savefig(p_uode, "universalode_extrapolation.png")
savefig(p_uode, "universalode_extrapolation.pdf")
ghost commented 2 years ago

Thanks for this - how did you find the right form for the [1][1], [1][2], etc? I was trying to find this! With the component array too. What exactly does the first [1] do?

Not being able to predict for [2:4] was something weird with having u0 in the predict function.

Even with your code, though, I am still seeing linear approximations for the first example and incorrect non-linear approximations for the second. Do you see the same?

ghost commented 2 years ago

@RajDandekar

RajDandekar commented 2 years ago

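@ccrnn: the first [1] picks the output vector out of the (output, state) tuple that a Lux model call returns. For `ann_node` that vector has 7 entries, of which the first 5 are used. We then need one more level of indexing to access each element separately.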

Regarding your second question: even in Chris's original code, the Neural ODE and UDE extrapolations are not good.

See this: https://github.com/ChrisRackauckas/universal_differential_equations/blob/master/SEIR_exposure/neuralode_extrapolation.png

and this: https://github.com/ChrisRackauckas/universal_differential_equations/blob/master/SEIR_exposure/universalode_extrapolation.png

For now, it's good that we match those results with SciML Sensitivity. We can indeed match the results.

We can spend some time later to maybe optimize the code hyperparameters etc to get better results.

SamuelBrand1 commented 6 months ago

Hey @RajDandekar

I notice that the SEIR example has not been updated to modern SciML usage; for example, it still uses DiffEqSensitivity and sciml_train. Is this still a TBD for you?

ChrisRackauckas commented 6 months ago

Someone needs to spend the time to update all of this. I think we want to maintain it in the SciMLDocs in the near future if someone takes the time.