ba2tro closed this issue 2 years ago.
Or is something missing? The p̂ in the first line is also not defined.
I get this plot using recovered_dynamics! and parameters(nn_res), as done in this line:
https://github.com/ChrisRackauckas/universal_differential_equations/blob/9fd95d7a9f920279beb0733be36bfa96b0cfb3bf/LotkaVolterra/scenario_1.jl#L182
There are no deprecation warnings in the code now, only the issues above.
The problem here seems to be that the neural network cannot be fit to a low loss. Every time I use ADAM followed by BFGS, the loss doesn't go below ~0.06; with PolyOpt I was able to reach around 0.03. If we can fit the neural network better here, the overall fit of the DataDrivenProblem could come out better.
What tolerance are you using for the ODE solver? Did you try decreasing it at all? And what about the learning rates?
I tried a range of tolerances from 1e-6 to 1e-10, along with several learning rates for ADAM, but the loss won't go down. The same problem occurs with the deprecated version.
What happens in the BFGS steps?
It terminates after around 950 steps with a loss of 0.06. I even tried running it again with remake, but then it doesn't take even one step and terminates straight away.
Using the v1.0 release of this repository, I was able to reproduce the original result of the NN fitting on Julia v1.5.x.
I suspect the NN constructor has changed, as well as the random number generator.
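A minimal sketch of why a fixed seed alone may not reproduce the old initial weights. Julia 1.7 switched the default RNG from MersenneTwister to Xoshiro256++, so Random.seed!(1234) yields a different stream than on Julia 1.5, and a changed initializer or layer constructor draws differently even from the same stream. Passing an explicitly seeded RNG (the script already hands one to Lux.setup(rng, U)) at least pins the draws within one Julia version. The helper glorot_uniform_sketch is hypothetical and only stands in for whatever initializer the NN library actually uses.

```julia
using Random

# Hypothetical helper: Glorot/Xavier-uniform initialisation of a dense layer's
# weight matrix, drawing all randomness from the RNG that is passed in.
function glorot_uniform_sketch(rng::AbstractRNG, n_out::Int, n_in::Int)
    limit = sqrt(6.0f0 / (n_in + n_out))
    # rand gives values in [0, 1); map them to [-limit, limit)
    return (rand(rng, Float32, n_out, n_in) .* 2.0f0 .- 1.0f0) .* limit
end

# Two RNGs created with the same seed give identical weights, independent of
# the global RNG state (Random.Xoshiro requires Julia >= 1.7).
W1 = glorot_uniform_sketch(Random.Xoshiro(1234), 10, 2)
W2 = glorot_uniform_sketch(Random.Xoshiro(1234), 10, 2)
println(W1 == W2)  # true
```

Changing either the Julia version (default RNG) or the order and shape of the draws inside the constructor changes W1, even though the seed is the same.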
Check the gradients for the same weights.
They do not check out.
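When two setups are compared (for example the v1.0 environment on Julia 1.5 and the updated one), a central finite-difference check is one way to tell whether a gradient is wrong or merely different. A minimal sketch, assuming the loss can be evaluated on a plain Vector{Float64}; fd_gradient, grad_mismatch and the toy quadratic loss are hypothetical stand-ins, not part of the thread's script.

```julia
using LinearAlgebra

# Central finite-difference estimate of the gradient of f at θ.
function fd_gradient(f, θ::Vector{Float64}; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        θp = copy(θ); θp[i] += h
        θm = copy(θ); θm[i] -= h
        g[i] = (f(θp) - f(θm)) / (2h)
    end
    return g
end

# Relative mismatch between two gradients evaluated at the same weights.
grad_mismatch(g1, g2) = norm(g1 .- g2) / max(norm(g1), norm(g2), eps())

# Toy loss with a known analytic gradient, standing in for the real training loss.
toy_loss(θ) = 0.5 * sum(abs2, θ) + θ[1] * θ[2]
toy_grad(θ) = θ .+ [θ[2], θ[1], zeros(length(θ) - 2)...]

θ = [1.0, -2.0, 0.5]
println(grad_mismatch(fd_gradient(toy_loss, θ), toy_grad(θ)))  # close to zero for a correct gradient
```

With the real model, the same comparison would be made between the gradients each environment reports at identical weights; how much mismatch to tolerate is a judgment call.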
After training:
I've put together a working example, albeit with different initial conditions (IC) for the Lotka-Volterra (LV) system and using maximum likelihood estimation (MLE):
## Environment and packages
cd(@__DIR__)
using Pkg; Pkg.activate("."); Pkg.instantiate()
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using DiffEqSensitivity
using Optimization
using OptimizationFlux
using OptimizationOptimJL
using Lux
using ComponentArrays
using Plots
gr()
using JLD2, FileIO
using Statistics
using Distributions
using Random
Random.seed!(1234)
#### NOTE
# Since DataDrivenDiffEq v0.6.0, which completely overhauled the optimizers,
# SR3 had been used here. STLSQ currently performs better, so it is used instead.
# Create a name for saving ( basically a prefix )
svname = "Scenario_1_"
## Data generation
function lotka!(du, u, p, t)
α, β, γ, δ = p
du[1] = α*u[1] - β*u[2]*u[1]
du[2] = γ*u[1]*u[2] - δ*u[2]
end
# Define the experimental parameter
tspan = (0.0f0,3.0f0)
u0 = rand(Float32, 2) .* 5.0f0 # Float32[0.44249296,4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1)
# Ideal data
X = Array(solution)
t = solution.t
# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = Float32(5e-2)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))
plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
## Define the network
# tanh as activation
activation = tanh
# Multilayer FeedForward
U = Lux.Chain(
Lux.Dense(2,10,activation), Lux.Dense(10,10, activation),Lux.Dense(10,10, activation), Lux.Dense(10,2)
)
# Get the initial parameters
rng = Random.default_rng()
#Random.seed!(rng, 0)
# Parameter and State Variables
ps, st = Lux.setup(rng, U)
p = ComponentVector((;logσ = zeros(Float32, 2), parameters = ps))
# Define the hybrid model
ude_dynamics! = (du,u, p, t) -> let p_true = p_, st = st
û, _ = U(u, p.parameters, st) # Network prediction
du[1] = p_true[1]*u[1] + û[1]
du[2] = -p_true[4]*u[2] + û[2]
end
# Define the problem
prob_nn = ODEProblem(ude_dynamics!, Xₙ[:, 1], tspan, p)
## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
solve(prob_nn, Vern7(), u0 = X, p=θ,
tspan = (T[1], T[end]), saveat = T,
abstol=1e-5, reltol=1e-5,
sensealg = ForwardDiffSensitivity(),
kwargshandle=KeywordArgSilent
)
end
predict(p)
# Gaussian negative log-likelihood (MLE) with a learned noise scale per state,
# plus a small L2 penalty on all parameters
loss = (θ) -> begin
X̂ = Array(predict(θ))
size(X̂) != size(Xₙ) && return Inf
dists = map(xi->Normal(zero(eltype(θ)), xi), exp.(θ.logσ))
e = Xₙ .- X̂
l = zero(eltype(θ))
for i in 1:size(X̂, 1), j in 1:size(X̂, 2)
l -= logpdf(dists[i], e[i,j])
end
l + 0.001f0*sum(abs2, θ)
end
loss(p)
# Container to track the losses
losses = Float32[]
# Callback to show the loss during training
callback(θ,l) = begin
push!(losses, l)
if length(losses)%50==0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
false
end
## Training
# First train with ADAM for better convergence, moving the parameters into a
# favourable starting position for BFGS
optf = OptimizationFunction((x, p)->loss(x), Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, p)
res1 = solve(optprob, ADAM(), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
optprob = OptimizationProblem(optf, res1.minimizer)
res2 = solve(optprob, BFGS(initial_stepnorm = 0.001f0), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
# Plot the losses
pl_losses = plot(1:200, losses[1:200], xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, joinpath(pwd(), "plots", "$(svname)_losses.pdf"))
# Rename the best candidate
p_trained = res2.minimizer
## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):mean(diff(solution.t)):last(solution.t)
sol_ = predict(p_trained, Xₙ[:,1], ts)
X̂ = Array(sol_)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
savefig(pl_trajectory, joinpath(pwd(), "plots", "$(svname)_trajectory_reconstruction.pdf"))
# Ideal unknown interactions of the predictor
Ȳ = [-p_[2]*(X̂[1,:].*X̂[2,:])';p_[3]*(X̂[1,:].*X̂[2,:])']
# Neural network guess
Ŷ, _ = U(X̂, p_trained.parameters, st)
pl_reconstruction = scatter(ts, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])
savefig(pl_reconstruction, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction.pdf"))
# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ-Ŷ)), yaxis = :log, xlabel = "t", ylabel = "L2-Error", label = nothing, color = :red)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2,1))
savefig(pl_missing, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction_and_error.pdf"))
pl_overall = plot(pl_trajectory, pl_missing)
savefig(pl_overall, joinpath(pwd(), "plots", "$(svname)_reconstruction.pdf"))
## Symbolic regression via sparse regression ( SINDy based )
# Create a Basis
@variables u[1:2]
u = collect(u)
# Generate the basis functions: multivariate polynomials up to degree 5
b = polynomial_basis(u, 5)
basis = Basis(b, u)
g(x) = x[1] < 1 ? Inf : prod(size(X̂))*log(x[1]) + 2x[1]
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.01:0.0))
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)
# Define different problems for the recovery
full_problem = DataDrivenProblem(solution)
ideal_problem = DirectDataDrivenProblem(X̂, Ȳ, t = ts)
real_problem = ContinuousDataDrivenProblem(Xₙ, t)
nn_problem = DirectDataDrivenProblem(X̂, Ŷ, t = ts)
# Test on ideal derivative data for unknown function ( not available )
println("Sparse regression")
sampler = DataSampler(Batcher(n = 1, shuffle = false, repeated = false))
full_res = solve(full_problem, basis, opt, g = g, maxiter = 10000, progress = true)
ideal_res = solve(ideal_problem, basis, opt, g = g, maxiter = 10000, progress = true)
nn_res = solve(nn_problem, basis, opt, g = g, by = :fold, maxiter = 10000, progress = true)
real_res = solve(real_problem, basis, opt, g = g, by = :fold, maxiter = 10000, progress = true)
# Store the results
results = [full_res; ideal_res; nn_res; real_res]
# Show the results
map(println, results)
# Show the recovered equations
map(println ∘ result, results)
# Show the identified parameters
map(println ∘ parameter_map, results)
# Define the recovered hybrid model
function recovered_dynamics!(du,u, p, t)
û = nn_res(u, p) # Prediction of the recovered sparse model
du[1] = p_[1]*u[1] + û[1]
du[2] = -p_[4]*u[2] + û[2]
end
estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, parameters(nn_res))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)
# Plot
plot(solution)
plot!(estimate)
## Simulation
# Look at long term prediction
t_long = (0.0f0, 50.0f0)
estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameters(nn_res))
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.1) # Using higher tolerances here results in exit of julia
plot(estimate_long)
true_prob = ODEProblem(lotka!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = estimate_long.t)
plot!(true_solution_long)
## Save the results
save(joinpath(pwd(), "results" ,"$(svname)recovery_$(noise_magnitude).jld2"),
"solution", solution, "X", Xₙ, "t" , ts, "neural_network" , U, "initial_parameters", p, "trained_parameters" , p_trained, # Training
"losses", losses, "result", nn_res, "recovered_parameters", parameters(nn_res), # Recovery
"long_solution", true_solution_long, "long_estimate", estimate_long) # Estimation
## Post Processing and Plots
c1 = 3 # RGBA(174/255,192/255,201/255,1) # Maroon
c2 = :orange # RGBA(132/255,159/255,173/255,1) # Red
c3 = :blue # RGBA(255/255,90/255,0,1) # Orange
c4 = :purple # RGBA(153/255,50/255,204/255,1) # Purple
p1 = plot(t,abs.(Array(solution) .- Array(estimate))' .+ eps(Float32),
lw = 3, yaxis = :log, title = "Timeseries of UODE Error",
color = [c1 c2], xlabel = "t",
label = ["x(t)" "y(t)"],
titlefont = "Helvetica", legendfont = "Helvetica",
legend = :topright)
# Plot the network fit of the second missing term, U₂
p2 = plot3d(X̂[1,:], X̂[2,:], Ŷ[2,:], lw = 3,
title = "Neural Network Fit of U2(t)", color = c1,
label = "Neural Network", xaxis = "x", yaxis="y",
titlefont = "Helvetica", legendfont = "Helvetica",
legend = :bottomright)
plot!(X̂[1,:], X̂[2,:], Ȳ[2,:], lw = 3, label = "True Missing Term", color=c2)
p3 = scatter(solution, color = [c1 c2], label = ["x data" "y data"],
title = "Extrapolated Fit From Short Training Data",
titlefont = "Helvetica", legendfont = "Helvetica",
markersize = 5)
plot!(p3,true_solution_long, color = [c1 c2], linestyle = :dot, lw=5, label = ["True x(t)" "True y(t)"])
plot!(p3,estimate_long, color = [c3 c4], lw=1, label = ["Estimated x(t)" "Estimated y(t)"])
plot!(p3,[2.99,3.01],[0.0,10.0],lw=1,color=:black, label = nothing)
annotate!([(1.5,13,text("Training \nData", 10, :center, :top, :black, "Helvetica"))])
l = @layout [grid(1,2)
grid(1,1)]
plot(p1,p2,p3,layout = l)
savefig(joinpath(pwd(),"plots","$(svname)full_plot.pdf"))
Try it now.
The issue is the noise. Not necessarily the amount, but the noise model doesn't make sense. It uses the mean of the whole trajectory, not the mean at the value. This makes it not proportional noise but some non-standard noise model (something which is not biologically plausible). The plausibility isn't the issue, though. The real issue is that it amplifies the noise around the lower values, the values near zero. If you look at the values near zero, you see that some almost cross the zero line. Because the trajectory almost hits zero, there is no possible positive interaction term that can cause this behavior, which makes the training go awry, because it's like "oh, I probably need to find a very different solution to make this possible".
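To make the point about the noise model concrete: the script draws Xₙ = X .+ (noise_magnitude*x̄) .* randn(...), so every entry in row i gets the same standard deviation noise_magnitude * mean(X[i, :]). The sketch below (toy data, hypothetical helper name) computes that standard deviation relative to each value: entries near zero end up with noise comparable to the value itself, while large entries see only a few percent.

```julia
using Statistics

# Mean-based noise model from the script: σ for row i is noise_magnitude * mean(X[i, :]).
# Returns σ divided by each entry, i.e. the noise level relative to the value it perturbs.
relative_noise(X, noise_magnitude) = (noise_magnitude .* mean(X, dims = 2)) ./ X

X = [4.0 2.0 0.1 3.0;
     1.0 0.05 2.0 0.95]   # toy data, not from the thread
rel = relative_noise(X, 0.05)
println(round.(rel, digits = 3))
```

Under a truly proportional model the same matrix would contain 0.05 everywhere.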
In normal scenarios, though, one would do a maximum likelihood fit and recognize that the noise, relative to the value, is larger near zero. Larger noise in a maximum likelihood fit is equivalent to an L2 fit with a weight function, where larger noise leads to a smaller weight. So, in a maximum likelihood sense, the "right" thing to do is to multiply the cost function for the small values by a really small number, saying "hey, these values are extra bad". You can try this, and I'm sure that would fix it too.
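The equivalence between maximum likelihood and a weighted L2 fit can be checked numerically. Assuming independent Gaussian errors e with known per-point standard deviations σ, the negative log-likelihood is ½ Σ e²/σ² plus a term that does not depend on the model, so minimizing it is the same as minimizing an L2 loss with weights 1/σ²: larger σ, smaller weight. The residuals and σ values below are made up for illustration.

```julia
# Gaussian negative log-likelihood with a separate σ for every data point.
gaussian_nll(e, σ) = sum(@. 0.5 * log(2π * σ^2) + e^2 / (2σ^2))

# Weighted L2 loss with weights 1 / σ²: noisier points count less.
weighted_l2(e, σ) = sum(@. e^2 / σ^2)

e = [0.1, -0.2, 0.05, 0.3]      # toy residuals
σ = [0.01, 0.05, 0.05, 0.3]     # toy per-point noise levels
const_term = sum(@. 0.5 * log(2π * σ^2))   # independent of the model's predictions
println(gaussian_nll(e, σ) ≈ 0.5 * weighted_l2(e, σ) + const_term)  # true
```

Since const_term is fixed once σ is fixed, both losses have the same minimizer; downweighting a point is the same as declaring it noisier.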
But anyway, the right thing to do would be to fix the error model so that it actually is proportional noise. It's written and listed as proportional noise, but it isn't, since it's proportional to the mean, which has very different (and unrealistic) characteristics. Much higher noise levels will probably work once this is fixed.
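A sketch of what a proportional noise model could look like, as an alternative to the mean-based Xₙ line in the script: each entry is multiplied by 1 + noise_magnitude*ε with ε ~ N(0, 1), so for positive data the standard deviation is noise_magnitude * X[i, j]. The function name proportional_noise and the toy data are hypothetical.

```julia
using Random

# Proportional (multiplicative) noise: entry (i, j) is scaled by 1 + noise_magnitude * ε
# with ε ~ N(0, 1), so for positive data its standard deviation is noise_magnitude * X[i, j].
proportional_noise(rng::AbstractRNG, X, noise_magnitude) =
    X .* (one(eltype(X)) .+ noise_magnitude .* randn(rng, eltype(X), size(X)))

X = [4.0 2.0 0.1 3.0;
     1.0 0.05 2.0 0.95]                      # toy data, not from the thread
Xn = proportional_noise(Random.Xoshiro(42), X, 0.05)

# The relative error is noise_magnitude * ε, however small the entry is.
rel_err = (Xn .- X) ./ X
```

In the script this would mean writing Xₙ = proportional_noise(Random.default_rng(), X, noise_magnitude) instead of the mean-based line, leaving the training code unchanged.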
Lift-off
Is the estimated_dynamics! here
https://github.com/ChrisRackauckas/universal_differential_equations/blob/9fd95d7a9f920279beb0733be36bfa96b0cfb3bf/LotkaVolterra/scenario_1.jl#L193
supposed to be this recovered_dynamics!?
https://github.com/ChrisRackauckas/universal_differential_equations/blob/9fd95d7a9f920279beb0733be36bfa96b0cfb3bf/LotkaVolterra/scenario_1.jl#L175