P.S. Just realised it was unclear which model is universal and which is neural; I have edited to clarify.
@rajdandekar
@ccrnn I have also been working on translating the UDE code into the SciML Sensitivity + Lux interface. Here are the key points, based on your prior comment:
(a) The code I have provided below mimics the original code closely.
(b) The plots for both prediction and estimation match the original plots in the paper.
(c) I have not yet done the SINDy part, but will implement it in the coming days; a rough sketch of what I have in mind follows the UDE code below.
Can you have a look at the code below and compare it with yours? It may also improve some of the results you are seeing on your end:
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using Lux, Optimization, OptimizationOptimJL, DiffEqFlux, Flux
using SciMLSensitivity, ComponentArrays # needed for InterpolatingAdjoint/ReverseDiffVJP and ComponentArray
using Plots
using Random
rng = Random.default_rng()
function corona!(du, u, p, t)
    S,E,I,R,N,D,C = u
    F, β0,α,κ,μ,σ,γ,d,λ = p
    dS = -β0*S*F/N - β(t,β0,D,N,κ,α)*S*I/N - μ*S # susceptible
    dE =  β0*S*F/N + β(t,β0,D,N,κ,α)*S*I/N - (σ+μ)*E # exposed
    dI = σ*E - (γ+μ)*I # infected
    dR = γ*I - μ*R # removed (recovered + dead)
    dN = -μ*N # total population
    dD = d*γ*I - λ*D # severe, critical cases, and deaths
    dC = σ*E # cumulative cases
    du[1] = dS; du[2] = dE; du[3] = dI; du[4] = dR
    du[5] = dN; du[6] = dD; du[7] = dC
end
β(t,β0,D,N,κ,α) = β0*(1-α)*(1-D/N)^κ
S0 = 14e6
u0 = [0.9*S0, 0.0, 0.0, 0.0, S0, 0.0, 0.0]
p_ = [10.0, 0.5944, 0.4239, 1117.3, 0.02, 1/3, 1/5, 0.2, 1/11.2] # F, β0, α, κ, μ, σ, γ, d, λ
R0 = p_[2]/p_[7]*p_[6]/(p_[6]+p_[5])
tspan = (0.0, 21.0)
prob = ODEProblem(corona!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
tspan2 = (0.0,60.0)
prob = ODEProblem(corona!, u0, tspan2, p_)
solution_extrapolate = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 1)
# Ideal data
tsdata = Array(solution)
# Add noise to the data
noisy_data = tsdata + Float32(1e-5)*randn(eltype(tsdata), size(tsdata))
plot(abs.(tsdata-noisy_data)')
### Neural ODE
ann_node = Lux.Chain(Lux.Dense(7, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 7))
p1, st1 = Lux.setup(rng, ann_node)
p = ComponentArray(p1)
function dudt_node(du, u, p, t)
    S,E,I,R,N,D,C = u
    F,β0,α,κ,μ,σ,γ,d,λ = p_
    # The Lux model returns an (output, state) tuple; [1] takes the output vector and
    # the second index selects a component. Only the first 5 network outputs are used.
    du[1] = dS = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][1]
    du[2] = dE = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][2]
    du[3] = dI = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][3]
    du[4] = dR = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][4]
    du[5] = dN = -μ*N # total population (u[5] is N, matching corona! above)
    du[6] = dD = ann_node([S/N,E,I,R,N,D/N,C], p, st1)[1][5] # severe, critical cases, and deaths (u[6] is D)
    du[7] = dC = σ*E # cumulative cases
end
prob_node = ODEProblem{true}(dudt_node, u0, tspan)
function predict(θ)
    Array(solve(prob_node, Tsit5(), p = θ, saveat = 1, abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end
# No regularisation right now
function loss(θ)
    pred = predict(θ)
    loss = sum(abs2, noisy_data[2:4,:] .- pred[2:4,:])
    return loss # + 1e-5*sum(abs, θ) for L1 regularisation of the Lux parameters
end
loss(p)
iter = 0
function callback(θ, l)
    global iter
    iter += 1
    if iter % 10 == 0
        println(l)
    end
    return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1 = Optimization.solve(optprob, ADAM(0.0001), callback = callback, maxiters = 1500)
optprob2 = remake(optprob,u0 = res1.u)
res2 = Optimization.solve(optprob2,Optim.BFGS(initial_stepnorm=0.01),
callback=callback,
maxiters = 10000)
data_pred = predict(res2.u)
scatter(solution, vars=[2,3,4], label=["True Exposed" "True Infected" "True Recovered"])
plot!(data_pred[2,:], label=["Estimated Exposed"])
plot!(data_pred[3,:], label=["Estimated Infected" ])
plot!(data_pred[4,:], label=["Estimated Recovered"])
# Plot the losses
# TO DO: plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")
# Extrapolate out
prob_node_extrapolate = ODEProblem{true}(dudt_node, u0, tspan2)
_sol_node = Array(solve(prob_node_extrapolate, Tsit5(),p = res2.u, saveat = 1,abstol=1e-12, reltol=1e-12,
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
p_node = scatter(solution_extrapolate, vars=[2,3,4], legend = :topleft, label=["True Exposed" "True Infected" "True Recovered"], title="Neural ODE Extrapolation")
plot!(p_node,_sol_node[2,:], lw = 5, label=["Estimated Exposed"])
plot!(p_node,_sol_node[3,:], lw = 5, label=["Estimated Infected" ])
plot!(p_node,_sol_node[4,:], lw = 5, label=["Estimated Recovered"])
plot!(p_node,[20.99,21.01],[0.0,maximum(hcat(Array(solution_extrapolate[2:4,:]),Array(_sol_node[2:4,:])))],lw=5,color=:black,label="Training Data End")
savefig("neuralode_extrapolation.png")
savefig("neuralode_extrapolation.pdf")
### Universal ODE Part 1
ann = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 64, tanh), Lux.Dense(64, 1))
p1, st1 = Lux.setup(rng, ann)
p = ComponentArray(p1)
function dudt_(du, u, p, t)
    S,E,I,R,N,D,C = u
    F, β0,α,κ,μ,σ,γ,d,λ = p_
    # The network replaces the unknown exposure term; [1] extracts the output vector
    # from the (output, state) tuple returned by the Lux model.
    z = ann([S/N, I, D/N], p, st1)[1]
    du[1] = dS = -β0*S*F/N - z[1] - μ*S # susceptible
    du[2] = dE =  β0*S*F/N + z[1] - (σ+μ)*E # exposed
    du[3] = dI = σ*E - (γ+μ)*I # infected
    du[4] = dR = γ*I - μ*R # removed (recovered + dead)
    du[5] = dN = -μ*N # total population
    du[6] = dD = d*γ*I - λ*D # severe, critical cases, and deaths
    du[7] = dC = σ*E # cumulative cases
end
prob_nn = ODEProblem{true}(dudt_,u0, tspan)
function predict(θ)
    Array(solve(prob_nn, Tsit5(), p = θ, saveat = solution.t, abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end
# No regularisation right now
function loss(θ)
    pred = predict(θ)
    loss = sum(abs2, noisy_data[2:4,:] .- pred[2:4,:])
    return loss # + 1e-5*sum(abs, θ) for L1 regularisation of the Lux parameters
end
loss(p)
iter = 0
function callback(θ, l)
    global iter
    iter += 1
    if iter % 50 == 0
        println(l)
    end
    return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
res1 = Optimization.solve(optprob, ADAM(0.01), callback = callback, maxiters = 500)
optprob2 = remake(optprob,u0 = res1.u)
res2 = Optimization.solve(optprob2,Optim.BFGS(initial_stepnorm=0.01),
callback=callback,
maxiters = 550)
uode_sol = predict(res2.u)
scatter(solution, vars=[2,3,4], label=["True Exposed" "True Infected" "True Recovered"])
plot!(uode_sol[2,:], label=["Estimated Exposed"])
plot!(uode_sol[3,:], label=["Estimated Infected" ])
plot!(uode_sol[4,:], label=["Estimated Recovered"])
# Plot the losses
#TO DO: plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")
# Collect the state trajectory and the derivatives
#X = noisy_data
# Ideal derivatives
#DX = Array(solution(solution.t, Val{1}))
# Extrapolate out
prob_nn2 = ODEProblem{true}(dudt_, u0, tspan2)
_sol_uode = Array(solve(prob_nn2, Tsit5(),p = res2.u, saveat = 1,abstol=1e-12, reltol=1e-12,
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))))
p_uode = scatter(solution_extrapolate, vars=[2,3,4], legend = :topleft, label=["True Exposed" "True Infected" "True Recovered"], title="Universal ODE Extrapolation")
plot!(p_uode,_sol_uode[2,:], lw = 5, label=["Estimated Exposed"])
plot!(p_uode,_sol_uode[3,:], lw = 5, label=["Estimated Infected" ])
plot!(p_uode,_sol_uode[4,:], lw = 5, label=["Estimated Recovered"])
plot!(p_uode,[20.99,21.01],[0.0,maximum(hcat(Array(solution_extrapolate[2:4,:]),Array(_sol_uode[2:4,:])))],lw=5,color=:black,label="Training Data End")
savefig("universalode_extrapolation.png")
savefig("universalode_extrapolation.pdf")
Thanks for this - how did you find the right form for the [1][1], [1][2], etc.? I was trying to find this, and the ComponentArray usage too. What exactly does the first [1] do?
Not being able to predict for [2:4] turned out to be something weird with having u0 in the predict function.
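(In case it is useful, one common pattern is to remake the problem inside predict so that u0 can be passed explicitly alongside the parameters; a minimal sketch, reusing prob_nn and the tolerances from above:)
function predict(θ, u0_ = u0)
    _prob = remake(prob_nn; u0 = u0_, p = θ)
    Array(solve(_prob, Tsit5(), saveat = solution.t, abstol = 1e-6, reltol = 1e-6,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end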
I am still seeing linear approximations for the first example, and incorrect nonlinear approximations for the second, even with your code, though?
@RajDandekar
@ccrnn: the first [1] extracts the network output from the (output, state) tuple that a Lux model returns, giving the vector of 5 elements. Then we access each element separately through one more level of indexing.
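To make the indexing concrete, here is a minimal standalone sketch of the Lux calling convention (the dimensions are made up, not taken from the code above):
using Lux, Random, ComponentArrays
rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(3, 8, tanh), Lux.Dense(8, 5))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)
x = rand(Float32, 3)
out = model(x, ps, st) # a tuple: (output vector, updated state)
y = out[1]             # the 5-element output vector, i.e. the first [1]
y2 = out[1][2]         # second component of the output, i.e. [1][2]
Calling the model once, binding out[1] to a variable, and then indexing that vector also avoids the repeated forward passes in dudt_node above.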
Regarding your second question: even in Chris's original code, the Neural ODE and the UDE extrapolations are not good.
For now, the important thing is that we can match those results with SciML Sensitivity, and we do.
We can spend some time later optimizing the hyperparameters etc. to get better results.
Hey @RajDandekar
I note that the SEIR example has not been updated to more modern SciML usage? E.g. DiffEqSensitivity and sciml_train are still used. Is this still a TBD for you?
Someone needs to spend the time to update all of this. I think we want to maintain it in the SciMLDocs in the near future if someone takes the time.
Translation of SEIR Example, based on Lotka Volterra 1:
Hiya, ok, here's the first...