ChrisRackauckas / universal_differential_equations

Repository for the Universal Differential Equations for Scientific Machine Learning paper, describing a computational basis for high performance SciML
https://arxiv.org/abs/2001.04385
MIT License

LV Scenario 2 #47

Open · ghost opened 1 year ago

ghost commented 1 year ago

Incorrect predictions, again....

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra 
using Optimization, OptimizationFlux, OptimizationOptimJL # OptimizationFlux provides ADAM, OptimizationOptimJL provides BFGS
using Lux
using SciMLSensitivity
using Plots
gr()
using JLD2, FileIO
using Statistics
# Set a random seed for reproducible behaviour
using Random
Random.seed!(2345)

#### NOTE
# Since the DataDrivenDiffEq v0.6.0 release, which completely overhauled the optimizers,
# SR3 had been used here. STLSQ currently performs better, so the scripts were switched to it.
# (An illustrative sketch of that sparse-regression step is appended after this script.)

# Create a name for saving ( basically a prefix )
svname = "Scenario_2_"

## Data generation
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[2]*u[1]
    du[2] = γ*u[1]*u[2]  - δ*u[2]
end

# Define the experimental parameter
tspan = (0.0f0,6.0f0)
u0 = Float32[0.44249296,4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-6, reltol=1e-6, saveat = 0.1)

scatter(solution, alpha = 0.25)
plot!(solution, alpha = 0.5)

# Ideal data
X = Array(solution)
t = solution.t
# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = Float32(1e-2)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))

# Subsample the data in y
# We assume we have only 5 measurements in y, evenly distributed
ty = collect(t[1]:Float32(6/5):t[end])
# Create datasets for the different measurements
round(Int64, mean(diff(ty))/mean(diff(t)))
XS = zeros(eltype(X), length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # All x data
TS = zeros(eltype(t), length(ty)-1, floor(Int64, mean(diff(ty))/mean(diff(t)))+1) # Time data
YS = zeros(eltype(X), length(ty)-1, 2) # Just two measurements in y

for i in 1:length(ty)-1
    idxs = ty[i].<= t .<= ty[i+1]
    XS[i, :] = Xₙ[1, idxs]
    TS[i, :] = t[idxs]
    YS[i, :] = [Xₙ[2, t .== ty[i]]'; Xₙ[2, t .== ty[i+1]]]
end

XS

scatter!(t, transpose(Xₙ))
## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))

# Define the network 2->5->5->5->2
U = Lux.Chain(
    Lux.Dense(2,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,2)
)

rng = Random.default_rng()
p1, st = Lux.setup(rng, U)

# Initial value for the extra trainable parameter (the predator decay rate)
parameter_array = Float64[0.5]
p = (layer_1 = p1, layer_2 = parameter_array)
p = Lux.ComponentArray(p)

# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
    û = U(u, p.layer_1, st)[1] # Network prediction
    du[1] = p_true[1]*u[1] + û[1]
    # We assume a linear decay rate for the predator
    du[2] = -p.layer_2[1]*u[2] + û[2]
end

p_true = 1.3
# Closure with the known parameter
nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_true)
# Define the problem

prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
    Array(solve(prob_nn, Vern7(), u0 = X, p=θ,
                saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = ForwardDiffSensitivity()
                ))
end

# Multiple shooting like loss
function loss(θ)
    # Start with a regularization on the network
    l = convert(eltype(θ), 1e-3)*sum(abs2, θ[2:end]) ./ length(θ[2:end])
    for i in 1:size(XS,1)
        X̂ = predict(θ, [XS[i,1], YS[i,1]], TS[i, :])
        # Full prediction in x
        l += sum(abs2, XS[i,:] .- X̂[1,:])
        # Add the boundary condition in y
        l += abs(YS[i, 2] .- X̂[2, end])
    end
    return l
end

# Container to track the losses
losses = Float32[]

# Callback to show the loss during training
callback(θ,l) = begin
    push!(losses, l)
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1 = Optimization.solve(optprob, ADAM(0.01f0), callback=callback, maxiters = 300)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS to achieve partial fit of the data
optprob2 = remake(optprob,u0 = res1.u)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01f0), callback=callback, maxiters = 10000, g_tol = 1e-10)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
pl_losses = plot(1:300, losses[1:300], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(301:length(losses), losses[301:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, "plot_losses.png")

# Rename the best candidate
p_trained = res2.minimizer

## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):mean(diff(solution.t))/2:last(solution.t)
X̂ = predict(p_trained, Xₙ[:, 1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel = "x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(t, X[1,:], color = :black, label = "Measurements")
ymeasurements = unique!(vcat(YS...))
tmeasurements = unique!(vcat([[ts[1], ts[end]] for ts in eachrow(TS)]...))
scatter!(tmeasurements, ymeasurements, color = :black, label = nothing, legend = :topleft)
savefig(pl_trajectory, "plot_trajectory_reconstruction.png")

# Ideal unknown interactions of the predictor
Ȳ = [-p_[2]*(X̂[1,:].*X̂[2,:])';p_[3]*(X̂[1,:].*X̂[2,:])']
# Neural network guess
Ŷ = U(X̂, p_trained.layer_1, st)[1]

pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing], legend = :topleft)
savefig(pl_reconstruction, "plot_missingterm_reconstruction.png")

# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ-Ŷ)), yaxis = :log, xlabel = "t", ylabel = "L2-Error", color = :red, label = nothing)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2,1))
savefig(pl_missing, "plots_missingterm_reconstruction_and_error.pdf")
pl_overall = plot(pl_trajectory, pl_missing)
savefig(pl_overall, "plots_reconstruction.png")
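
## Sparse regression (illustrative sketch, not part of the posted script)
# The NOTE near the top refers to this step, which the script above stops short of.
# Assuming the DataDrivenDiffEq v0.6-era API, using STLSQ rather than SR3 would look
# roughly like the following; X̂ and Ŷ are the states and network outputs defined above.
@variables u[1:2]
basis = Basis(polynomial_basis(u, 4), u)
nn_problem = DirectDataDrivenProblem(X̂, Ŷ)
nn_res = solve(nn_problem, basis, STLSQ(exp10.(-3:0.01:3)))
println(nn_res)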

@RajDandekar

AlCap23 commented 1 year ago

I've adapted scenarios 1-3, so it should be working in #48.

ghost commented 1 year ago

Thaaaanks! How did you find the right format for the setup for this part? Is there documentation on this somewhere or do you just know it?

# Merge the parameters
p = (;δ = rand(rng), ude = p_nn)
p = ComponentVector{Float64}(p)
# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
    û = U(u, p.ude, st_nn)[1] # Network prediction
    du[1] = p_true[1]*u[1] + û[1]
    # We assume a linear decay rate for the predator
    du[2] = -p.δ*u[2] + û[2]
end
AlCap23 commented 1 year ago

I've recently switched to Lux and have set up some models in the past. NamedTuples and ComponentVectors are really helpful in structuring the overall parameters, and both are usable for AD.
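
A minimal sketch of that pattern (the names here are illustrative, not taken from the scenario script): a NamedTuple holding a scalar next to the Lux parameters, flattened into a ComponentVector and differentiated with Zygote.

using Lux, ComponentArrays, Zygote, Random

rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(2, 5, tanh), Lux.Dense(5, 2))
p_nn, st_nn = Lux.setup(rng, model)

# Scalar parameter and network parameters side by side, flattened for the optimizer / AD
p = ComponentVector{Float64}((; δ = rand(rng), ude = p_nn))

function toy_loss(p)
    û = first(model(ones(Float32, 2), p.ude, st_nn))  # drop the returned state
    return sum(abs2, û) + p.δ^2
end

∇p = Zygote.gradient(toy_loss, p)[1]  # gradient over δ and the network parameters at once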

Additionally, I found this tutorial to be most insightful. Since I do not need a specific state, I just drop the information.

A more general way would be something along the lines of

mutable struct LuxContainer
    model
    state
end

(c::LuxContainer)(x, p) = begin
    out, state = c.model(x, p, c.state)
    c.state = state
    return out
end
ghost commented 1 year ago

Thanks, that's helpful. I guess I need to slow down and really work step by step through these tutorials instead of scanning for stuff that looks like it might fit. I feel like I'm missing the "why" behind why some things work and others don't...