SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

usage of universal Differential Algebraic Equations #842

Closed · ghost closed this 1 year ago

ghost commented 1 year ago

Hello,

I have been trying to create a universal Differential Algebraic Equation (I want to enforce some physical constraints).

There is a test case for this in the test suite. However, I could not run it, so I tried a couple of modifications to make it work.

Here is what I got:

script 1

using DiffEqFlux, OrdinaryDiffEq, Test, Plots, Lux, StableRNGs, ComponentArrays, Sundials

using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Robertson stiff reaction system in mass-matrix DAE form:
# equations 1-2 are differential, equation 3 is the algebraic
# conservation constraint y₁ + y₂ + y₃ = 1.
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] =  k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] =  y₁ + y₂ + y₃ - 1
    nothing
end

u₀ = [1.0, 0, 0]

# Singular mass matrix: the zero third row marks equation 3 as algebraic
M = [1. 0  0
     0  1. 0
     0  0  0]

tspan = (0.0, 10.0)
p_true = [0.04, 3e7, 1e4]

func = ODEFunction(f!, mass_matrix=M)
prob = ODEProblem(func, u₀, tspan, p_true)
sol = solve(prob, Rodas5(), saveat = 0.1, abstol = 1e-9, reltol = 1e-9)

t_true = sol.t

# Neural Network
rng = StableRNG(1111);
U = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
p, st = Lux.setup(rng, U)
p = ComponentArray(p)

function ude_dynamics!(du, u, p, t, p_true)
    # p_true is accepted for the closure below but unused in this model
    NN = U(u, p, st)[1]   # the network replaces the two differential equations
    du[1] = NN[1]
    du[2] = NN[2]
    du[3] = u[1] + u[2] + u[3] - 1.0   # enforce the algebraic constraint exactly
    nothing
end

# Closure with the known parameters
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_true);

prob_nn = ODEProblem(ODEFunction(nn_dynamics!, mass_matrix = M), u₀, tspan, p);

function predict(θ, X = u₀, T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)

    Array(solve(_prob, Rodas5(), saveat = T,
                abstol = 1e-4, reltol = 1e-4, verbose=false))
end

function loss(p)
    pred = predict(p)
    # compare against the reference solution as a plain array
    loss = sum(abs2, Array(sol) .- pred)
    loss, pred
end

losses = Float64[];

callback = function (p, l, pred)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
        println("Current min/max after $(length(losses)) iterations: $(extrema(pred[1, :] + pred[2, :] + pred[3, :]))")
    end
    return false
end

optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optfunc, p)
res = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 100)

In the same UDAE context, I also created a script based on the UODE approach from the Lotka-Volterra example:

script 2

using DiffEqFlux, OrdinaryDiffEq, Test, Plots, Lux, StableRNGs, ComponentArrays

using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Same Robertson system in mass-matrix DAE form as in script 1
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] =  k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] =  y₁ + y₂ + y₃ - 1
    nothing
end

u₀ = [1.0, 0, 0]

M = [1. 0  0
     0  1. 0
     0  0  0]

tspan = (0.0, 10.0)
p_true = [0.04, 3e7, 1e4]

func = ODEFunction(f!, mass_matrix=M)
prob = ODEProblem(func, u₀, tspan, p_true)
sol = solve(prob, Rodas5(), saveat = 0.1, abstol = 1e-9, reltol = 1e-9)

t_true = sol.t

# Neural Network
rng = StableRNG(1111);
U = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
p, st = Lux.setup(rng, U)
p = ComponentArray(p)

function ude_dynamics!(du, u, p, t, p_true)
    # p_true is accepted for the closure below but unused in this model
    NN = U(u, p, st)[1]   # the network replaces the two differential equations
    du[1] = NN[1]
    du[2] = NN[2]
    du[3] = u[1] + u[2] + u[3] - 1.0   # enforce the algebraic constraint exactly
    nothing
end

# Closure with the known parameters
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_true);

prob_nn = ODEProblem(ODEFunction(nn_dynamics!, mass_matrix = M), u₀, tspan, p);

function predict(θ, X = u₀, T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)

    Array(solve(_prob, Rodas5(), saveat = T,
                abstol = 1e-4, reltol = 1e-4, verbose=false))
end

function loss(p)
    pred = predict(p)
    # compare against the reference solution as a plain array
    loss = sum(abs2, Array(sol) .- pred)
    loss, pred
end

losses = Float64[];

callback = function (p, l, pred)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
        println("Current min/max after $(length(losses)) iterations: $(extrema(pred[1, :] + pred[2, :] + pred[3, :]))")
    end
    return false
end

optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optfunc, p)
res = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 100)

The end result of both scripts is similar but not exactly the same (I don't know why, but I guess it is some floating-point round-off).

However, the computational time is very high. Did I do something wrong that is costing a lot of resources? Or are UDAEs expensive by themselves? What can I do to make the code faster?

Best Regards

ChrisRackauckas commented 1 year ago

> There is a test case for this in the test suite. However, I could not run it, so I tried a couple of modifications to make it work.

That test passes on the latest versions, so just make sure you're up to date (Julia v1.9 with the latest DiffEqFlux and SciMLSensitivity). I just ran the test suite and it went fine.
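
A quick way to check is to update and then inspect the resolved versions (a minimal sketch using the standard Pkg API):

using Pkg
Pkg.update()                                     # pull the latest compatible releases
Pkg.status(["DiffEqFlux", "SciMLSensitivity"])   # confirm which versions resolved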

> The end result of both scripts is similar but not exactly the same (I don't know why, but I guess it is some floating-point round-off).

Solving to 1e-4 accuracy locally is only about 1e-3 to 1e-2 accuracy globally, so at each of the 100 optimization steps you only have a few digits of accuracy in the loss. Yeah, that's not going to be the most stable. If you need more stability, lower the tolerances.
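
For example, tightening the tolerances inside predict (a sketch; the 1e-8 values are just an illustrative choice, not a recommendation from this thread):

function predict(θ, X = u₀, T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    # tighter tolerances trade runtime for a more reproducible loss
    Array(solve(_prob, Rodas5(), saveat = T,
                abstol = 1e-8, reltol = 1e-8, verbose = false))
end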

> However, the computational time is very high. Did I do something wrong that is costing a lot of resources? Or are UDAEs expensive by themselves? What can I do to make the code faster?

Rodas5 will be quite expensive here with this choice of adjoint. Using sensealg = GaussAdjoint() (which just merged today) should be a lot faster for this use case. Also, you may want to look into FBDF() as the solver here. Both should cut the cost of the adjoint pass down a lot.
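
Combining both suggestions in the predict function could look like this (a sketch, assuming SciMLSensitivity is loaded so that GaussAdjoint is available):

using SciMLSensitivity  # provides GaussAdjoint

function predict(θ, X = u₀, T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    # FBDF as the stiff mass-matrix solver, GaussAdjoint for the sensitivities
    Array(solve(_prob, FBDF(), saveat = T,
                abstol = 1e-4, reltol = 1e-4, verbose = false,
                sensealg = GaussAdjoint()))
end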

But there doesn't seem to be anything actionable here, so I'm closing it. Feel free to keep asking questions, though for usage questions and non-bug reports we recommend the Discourse: https://discourse.julialang.org/