ChrisRackauckas / universal_differential_equations

Repository for the Universal Differential Equations for Scientific Machine Learning paper, describing a computational basis for high performance SciML
https://arxiv.org/abs/2001.04385
MIT License

Updated the implementation #40

Abhishek-1Bhatt closed 2 years ago

Abhishek-1Bhatt commented 2 years ago

Is the estimated_dynamics! here https://github.com/ChrisRackauckas/universal_differential_equations/blob/9fd95d7a9f920279beb0733be36bfa96b0cfb3bf/LotkaVolterra/scenario_1.jl#L193

supposed to be this recovered_dynamics! https://github.com/ChrisRackauckas/universal_differential_equations/blob/9fd95d7a9f920279beb0733be36bfa96b0cfb3bf/LotkaVolterra/scenario_1.jl#L175

Abhishek-1Bhatt commented 2 years ago

Or is something missing? Because the p̂ in the first line is also not defined. I get this plot using recovered_dynamics! and parameters(nn_res) as done in this line https://github.com/ChrisRackauckas/universal_differential_equations/blob/9fd95d7a9f920279beb0733be36bfa96b0cfb3bf/LotkaVolterra/scenario_1.jl#L182

[image: plot]

There are no deprecations in the code now, just the above things

Abhishek-1Bhatt commented 2 years ago

The problem here seems to be that I cannot get a low-loss fit of the neural network. Every time I use ADAM and BFGS, the loss doesn't go below ~0.06. With PolyOpt I was able to reach around 0.03. If we can fit the neural network better here, the overall fit of the DataDrivenProblem could come out better.

[image]

ChrisRackauckas commented 2 years ago

what tolerance for the ODE solver? Did you try decreasing it at all? And what about the learning rates?

Abhishek-1Bhatt commented 2 years ago

I tried a bunch of different tolerances from 1e-6 to 1e-10, along with some learning rates for ADAM, but the loss won't go down. The same problem occurs with the deprecated version.

ChrisRackauckas commented 2 years ago

What happens in the BFGS steps?

Abhishek-1Bhatt commented 2 years ago

It terminates at around 950 steps with a loss of 0.06. I even tried to run it again with remake, but then it won't take even one step and terminates straight away.

AlCap23 commented 2 years ago

Using the v1.0 release of this repository, I was able to reproduce the original result of the NN fitting on Julia v1.5.x.

[images: v1_5_4, v154_sample2]

I suspect the constructor for the NN has changed, as well as the random number generator.

ChrisRackauckas commented 2 years ago

check the gradients for the same weights.
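
For reference, one way to run such a check is to evaluate the loss gradient at identical weights in both environments and compare it against central finite differences. A minimal sketch in plain Julia follows; `fd_gradient` and the toy `loss` are hypothetical stand-ins for illustration, not code from the repository:

```julia
# Central finite-difference gradient of a scalar function f at x.
# Hypothetical helper; the step size h is an assumed default.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(eltype(x), length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Toy loss whose exact gradient is known: ∇(sum(abs2, w) / 2) = w
loss(w) = sum(abs2, w) / 2

w = [1.0, -2.0, 3.0]          # the "same weights" to test in each environment
fd = fd_gradient(loss, w)
analytic = w                  # replace with the AD gradient of the real loss
println("max abs difference: ", maximum(abs.(fd .- analytic)))
```

If the difference is large in one Julia version but not in the other for the same weights, the gradient computation is the likely suspect rather than the initialization.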

AlCap23 commented 2 years ago

They do not check out.

initial_gradient_fd.txt

After training:

[image: 1_7_training_same_init]

I've been able to give a working example, albeit with different IC for the LV system and using MLE:

## Environment and packages
cd(@__DIR__)
using Pkg; Pkg.activate("."); Pkg.instantiate()

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra
using DiffEqSensitivity

using Optimization
using OptimizationFlux
using OptimizationOptimJL

using Lux
using ComponentArrays
using Plots
gr()
using JLD2, FileIO
using Statistics
using Distributions
using Random
Random.seed!(1234)
#### NOTE
# Since the recent release of DataDrivenDiffEq v0.6.0 where a complete overhaul of the optimizers took
# place, SR3 has been used. Right now, STLSQ performs better and has been changed.

# Create a name for saving ( basically a prefix )
svname = "Scenario_1_"

## Data generation
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[2]*u[1]
    du[2] = γ*u[1]*u[2]  - δ*u[2]
end

# Define the experimental parameter
tspan = (0.0f0,3.0f0)
u0 = rand(Float32, 2) .* 5.0f0#Float32[0.44249296,4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1)

# Ideal data
X = Array(solution)
t = solution.t

# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = Float32(5e-2)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])

## Define the network
# tanh as activation
activation = tanh
# Multilayer FeedForward
U = Lux.Chain(
    Lux.Dense(2,10,activation), Lux.Dense(10,10, activation),Lux.Dense(10,10, activation), Lux.Dense(10,2)
)

# Get the initial parameters
rng = Random.default_rng()
#Random.seed!(rng, 0)

# Parameter and State Variables
ps, st = Lux.setup(rng, U)
p = ComponentVector((;logσ = zeros(Float32, 2), parameters = ps))
# Define the hybrid model
ude_dynamics! = (du,u, p, t) -> let p_true = p_, st = st
    û, _ = U(u, p.parameters, st) # Network prediction
    du[1] = p_true[1]*u[1] + û[1]
    du[2] = -p_true[4]*u[2] + û[2]
end

# Define the problem
prob_nn = ODEProblem(ude_dynamics!, Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
    solve(prob_nn, Vern7(), u0 = X, p=θ,
                tspan = (T[1], T[end]), saveat = T,
                abstol=1e-5, reltol=1e-5,
                sensealg = ForwardDiffSensitivity(), 
                kwargshandle=KeywordArgSilent
                )
end

predict(p)

# Gaussian negative log-likelihood with a learned noise scale (logσ) plus an L2 penalty on the parameters
loss = (θ) -> let d = Normal
    X̂ = Array(predict(θ))
    size(X̂) != size(Xₙ) && return Inf
    dists = map(xi->Normal(zero(eltype(θ)), xi), exp.(θ.logσ))
    e = Xₙ .- X̂
    l = zero(eltype(θ))
    for i in 1:size(X̂, 1), j in 1:size(X̂, 2)
        l -= logpdf(dists[i], e[i,j])
    end
    l + 0.001f0*sum(abs2, θ)
end

loss(p)

# Container to track the losses
losses = Float32[]

# Callback to show the loss during training
callback(θ,l) = begin
    push!(losses, l)
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
optf = OptimizationFunction((x, p)->loss(x), Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, p)
res1 = solve(optprob, ADAM(), callback=callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
optprob = OptimizationProblem(optf, res1.minimizer)
res2 = solve(optprob, BFGS(initial_stepnorm = 0.001f0), callback=callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
pl_losses = plot(1:200, losses[1:200], xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end],  xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, joinpath(pwd(), "plots", "$(svname)_losses.pdf"))
# Rename the best candidate
p_trained = res2.minimizer

## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):mean(diff(solution.t)):last(solution.t)
sol_ = predict(p_trained, Xₙ[:,1], ts)
X̂ = Array(sol_)
# Trained on noisy data vs real solution

pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
savefig(pl_trajectory, joinpath(pwd(), "plots", "$(svname)_trajectory_reconstruction.pdf"))

# Ideal unknown interactions of the predictor
Ȳ = [-p_[2]*(X̂[1,:].*X̂[2,:])';p_[3]*(X̂[1,:].*X̂[2,:])']
# Neural network guess
Ŷ, _ = U(X̂, p_trained.parameters, st)

pl_reconstruction = scatter(ts, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])
savefig(pl_reconstruction, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction.pdf"))

# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ-Ŷ)), yaxis = :log, xlabel = "t", ylabel = "L2-Error", label = nothing, color = :red)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2,1))
savefig(pl_missing, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction_and_error.pdf"))
pl_overall = plot(pl_trajectory, pl_missing)
savefig(pl_overall, joinpath(pwd(), "plots", "$(svname)_reconstruction.pdf"))
## Symbolic regression via sparse regression ( SINDy based )

# Create a Basis
@variables u[1:2]
u = collect(u)
# Generate the basis functions: multivariate polynomials up to degree 5
b = polynomial_basis(u, 5)
basis = Basis(b, u)

g(x) = x[1] < 1 ? Inf : prod(size(X̂))*log(x[1]) + 2x[1] 
# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.01:0.0))
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)
# Define different problems for the recovery
full_problem = DataDrivenProblem(solution)
ideal_problem = DirectDataDrivenProblem(X̂, Ȳ, t = ts)
real_problem = ContinuousDataDrivenProblem(Xₙ, t)
nn_problem = DirectDataDrivenProblem(X̂, Ŷ, t = ts)
# Test on ideal derivative data for unknown function ( not available )
println("Sparse regression")
sampler = DataSampler(Batcher(n = 1, shuffle = false, repeated = false))

full_res = solve(full_problem, basis, opt, g = g, maxiter = 10000, progress = true)
ideal_res = solve(ideal_problem, basis, opt, g = g, maxiter = 10000, progress = true)
nn_res = solve(nn_problem, basis, opt,  g = g,  by = :fold, maxiter = 10000, progress = true)
real_res = solve(real_problem, basis, opt, g = g, by = :fold, maxiter = 10000, progress = true)

# Store the results
results = [full_res; ideal_res; nn_res; real_res]
# Show the results
map(println, results)
# Show the results
map(println ∘ result, results)
# Show the identified parameters
map(println ∘ parameter_map, results)

# Define the recovered, hybrid model
function recovered_dynamics!(du,u, p, t)
    û = nn_res(u, p) # Network prediction
    du[1] = p_[1]*u[1] + û[1]
    du[2] = -p_[4]*u[2] + û[2]
end

estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, parameters(nn_res))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)

# Plot
plot(solution)
plot!(estimate)

## Simulation

# Look at long term prediction
t_long = (0.0f0, 50.0f0)
estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameters(nn_res))
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.1) # Using higher tolerances here results in exit of julia
plot(estimate_long)

true_prob = ODEProblem(lotka!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = estimate_long.t)
plot!(true_solution_long)

## Save the results
save(joinpath(pwd(), "results" ,"$(svname)recovery_$(noise_magnitude).jld2"),
    "solution", solution, "X", Xₙ, "t" , ts, "neural_network" , U, "initial_parameters", p, "trained_parameters" , p_trained, # Training
    "losses", losses, "result", nn_res, "recovered_parameters", parameters(nn_res), # Recovery
    "long_solution", true_solution_long, "long_estimate", estimate_long) # Estimation

## Post Processing and Plots

c1 = 3 # RGBA(174/255,192/255,201/255,1) # Maroon
c2 = :orange # RGBA(132/255,159/255,173/255,1) # Red
c3 = :blue # RGBA(255/255,90/255,0,1) # Orange
c4 = :purple # RGBA(153/255,50/255,204/255,1) # Purple

p1 = plot(t,abs.(Array(solution) .- estimate)' .+ eps(Float32),
          lw = 3, yaxis = :log, title = "Timeseries of UODE Error",
          color = [3 :orange], xlabel = "t",
          label = ["x(t)" "y(t)"],
          titlefont = "Helvetica", legendfont = "Helvetica",
          legend = :topright)

# Plot L₂
p2 = plot3d(X̂[1,:], X̂[2,:], Ŷ[2,:], lw = 3,
     title = "Neural Network Fit of U2(t)", color = c1,
     label = "Neural Network", xaxis = "x", yaxis="y",
     titlefont = "Helvetica", legendfont = "Helvetica",
     legend = :bottomright)
plot!(X̂[1,:], X̂[2,:], Ȳ[2,:], lw = 3, label = "True Missing Term", color=c2)

p3 = scatter(solution, color = [c1 c2], label = ["x data" "y data"],
             title = "Extrapolated Fit From Short Training Data",
             titlefont = "Helvetica", legendfont = "Helvetica",
             markersize = 5)

plot!(p3,true_solution_long, color = [c1 c2], linestyle = :dot, lw=5, label = ["True x(t)" "True y(t)"])
plot!(p3,estimate_long, color = [c3 c4], lw=1, label = ["Estimated x(t)" "Estimated y(t)"])
plot!(p3,[2.99,3.01],[0.0,10.0],lw=1,color=:black, label = nothing)
annotate!([(1.5,13,text("Training \nData", 10, :center, :top, :black, "Helvetica"))])
l = @layout [grid(1,2)
             grid(1,1)]
plot(p1,p2,p3,layout = l)

savefig(joinpath(pwd(),"plots","$(svname)full_plot.pdf"))

ChrisRackauckas commented 2 years ago

Try it now.

The issue is the noise. Not necessarily the amount, but the noise model doesn't make sense. It uses the mean of the whole trajectory, not the mean at the value. This makes it not proportional noise but some non-standard noise model (something which is not biologically plausible). It's not the plausibility that's the issue though. The real issue is that it amplifies the noise around the lower values, the values near zero. If you look at the lower values near zero, you see that some almost cross the line. Because it almost hits zero, there is no possible positive interaction term which can cause this behavior... which makes the training go awry because it's like "oh I need to probably find a very different solution to make this possible".

But in normal scenarios, what one would do is a maximum likelihood fit and recognize that the proportional noise is larger near zero. Larger proportional noise in a maximum likelihood fit equates to doing L2 fits with a weight function, where larger noise leads to a downweight. So basically, the "right" thing to do in a maximum likelihood sense is to multiply the cost function for the small values by a really small number saying "hey, these values are extra bad". You can try this, and I'm sure that would fix it too.
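
As a rough illustration of that equivalence: a Gaussian negative log-likelihood with known per-point standard deviations differs from a 1/σ²-weighted L2 loss only by a term that does not depend on the prediction. The sketch below uses made-up numbers and hypothetical helper names (`weighted_l2`, `gaussian_nll`); it is not code from the repository:

```julia
# Weighted L2 loss: each residual is divided by its (assumed known) noise level σ,
# so noisier points contribute less.
weighted_l2(data, pred, σ) = sum(abs2, (data .- pred) ./ σ) / 2

# Gaussian negative log-likelihood with per-point standard deviation σ.
function gaussian_nll(data, pred, σ)
    r = data .- pred
    return sum(log.(σ) .+ 0.5 * log(2π) .+ 0.5 .* (r ./ σ) .^ 2)
end

data  = [0.1, 1.0, 5.0]       # made-up measurements
σ     = [0.5, 0.5, 0.5]       # made-up noise levels (constant absolute noise)
pred1 = [0.2, 0.9, 5.3]       # two made-up candidate predictions
pred2 = [0.0, 1.5, 4.0]

# The offset between the two losses is the same for both predictions,
# so minimizing one minimizes the other.
offset1 = gaussian_nll(data, pred1, σ) - weighted_l2(data, pred1, σ)
offset2 = gaussian_nll(data, pred2, σ) - weighted_l2(data, pred2, σ)
println(offset1 ≈ offset2)
```

When σ is itself a trainable parameter, as with `logσ` in the script above, the `log.(σ)` term no longer drops out and has to stay in the loss.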

But anyways, the right thing to do would be to fix the error model to actually be proportional noise. It's written and listed as proportional noise, but it's not, since it's proportional to the mean which has very different (and unrealistic) characteristics. Much higher noise values probably work when this is fixed.
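
A minimal sketch of the difference between the two noise models, using a made-up 2×3 data matrix rather than the actual Lotka-Volterra trajectory (the names `σ_mean`, `σ_prop`, `rel_mean`, and `rel_prop` are illustrative assumptions):

```julia
using Random, Statistics

rng = MersenneTwister(1234)
X = Float32[0.05 0.5 5.0; 4.0 2.0 0.1]   # made-up states × time points
noise_magnitude = 5f-2

# Noise scaled by the per-state mean over the whole trajectory:
# every time point in a row gets the same standard deviation.
x̄ = mean(X, dims = 2)
σ_mean = repeat(noise_magnitude .* x̄, 1, size(X, 2))

# Proportional noise: the standard deviation scales with each individual value.
σ_prop = noise_magnitude .* abs.(X)

X_mean_noise = X .+ σ_mean .* randn(rng, Float32, size(X))
X_prop_noise = X .+ σ_prop .* randn(rng, Float32, size(X))

# Relative noise level at each point under the two models
rel_mean = σ_mean ./ abs.(X)
rel_prop = σ_prop ./ abs.(X)
println("largest relative noise, mean-scaled:  ", maximum(rel_mean))
println("largest relative noise, proportional: ", maximum(rel_prop))
```

Under the mean-scaled model the relative noise grows as a value approaches zero, which is where noisy samples can come close to crossing zero; under the proportional model it stays at `noise_magnitude` everywhere.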

Abhishek-1Bhatt commented 2 years ago

[image]

Lift-off