EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Simplified example of UDE dynamics with Lux neural networks #1057

Closed: ChrisRackauckas closed this issue 1 year ago

ChrisRackauckas commented 1 year ago

Boiling https://docs.sciml.ai/Overview/stable/showcase/missing_physics/ down to the Enzyme part. Start with:

# SciML Tools
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()

# Set a random seed for reproducible behaviour
const rng = StableRNG(1111)

begin

    function lotka!(du, u, p, t)
        α, β, γ, δ = p
        du[1] = α * u[1] - β * u[2] * u[1]
        du[2] = γ * u[1] * u[2] - δ * u[2]
    end

    # Define the experimental parameter
    tspan = (0.0, 5.0)
    u0 = 5.0f0 * rand(rng, 2)
    p_ = [1.3, 0.9, 0.8, 1.8]
    prob = ODEProblem(lotka!, u0, tspan, p_)
    solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

    # Add noise in terms of the mean
    const X = Array(solution)
    const t = solution.t

    x̄ = mean(X, dims = 2)
    noise_magnitude = 5e-3
    const Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

    plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
    scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])

    rbf(x) = exp.(-(x .^ 2))

    # Multilayer feed-forward network used by the hybrid model below
    # (restored here from the reduced example later in this issue)
    U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                  Lux.Dense(5, 2))
    p, st = Lux.setup(rng, U)

    # Define the hybrid model
    function ude_dynamics!(du, u, p, t, p_true)
        û = U(u, p, st)[1] # Network prediction
        du[1] = p_true[1] * u[1] + û[1]
        du[2] = -p_true[4] * u[2] + û[2]
    end

    # Closure with the known parameter
    nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
    # Define the problem

    const prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

    function predict(θ, X, T)
        _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
        Array(solve(_prob, Vern7(), saveat = T,
                    abstol = 1e-6, reltol = 1e-6,
                    sensealg = InterpolatingAdjoint(autojacvec=EnzymeVJP())))
    end

    function loss(θ)
        X̂ = predict(θ, Xₙ[:, 1], t)
        mean(abs2, Xₙ .- X̂)
    end
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 5 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

Boils down to:

using Enzyme, Lux, Random, ComponentArrays
Enzyme.API.typeWarning!(false)
rng_ = Random.default_rng()

rbf(x) = exp.(-(x .^ 2))
U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                Lux.Dense(5, 2))
p, st = Lux.setup(rng_, U)
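# Flatten the NamedTuple of weights into a single contiguous Float64 vector;
# ComponentArrays keeps the named axes so Lux can still index the layers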
p = ComponentVector{Float64}(p)

u = 5.0f0 * rand(2)
p_ = [1.3, 0.9, 0.8, 1.8]

function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
    nothing
end

nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)

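# Primal output and zero-initialized shadow buffers; the autodiff call below
# fails during compilation, so no seeding of d_du is needed to reproduce it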
du = zero(u)
d_du = zero(u)
d_u = zero(u)
d_p = zero(p)
t = 0.0

Enzyme.autodiff(Reverse, nn_dynamics!, Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p), Const(t))

Which gives:

Closest candidates are:
  asprogress(::Any, ::Any, ::Any, ::Any, ::Any, ::Any, ::Any; progress, kwargs...) (method too new to be called from this world context.)
   @ ProgressLogging C:\Users\accou\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:156
  asprogress(::Any, ::ProgressLogging.Progress, ::Any...; _...) (method too new to be called from this world context.)
   @ ProgressLogging C:\Users\accou\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:155
  asprogress(::Any, ::ProgressLogging.ProgressString, ::Any...; _...) (method too new to be called from this world context.)      
   @ ProgressLogging C:\Users\accou\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:200

Stacktrace:
  [1] (::VSCodeServer.var"#60#61"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, VSCodeServer.var"#57#58", Tuple{Base.CoreLogging.LogLevel, String, Module, Symbol, Symbol, String, Int64}, Module})()
    @ VSCodeServer c:\Users\accou\.vscode\extensions\julialang.language-julia-1.51.2\scripts\packages\VSCodeServer\src\progress.jl:56
  [2] #invokelatest#2
    @ .\essentials.jl:819 [inlined]
  [3] invokelatest
    @ .\essentials.jl:816 [inlined]
  [4] #try_process_progress#59
    @ c:\Users\accou\.vscode\extensions\julialang.language-julia-1.51.2\scripts\packages\VSCodeServer\src\progress.jl:55 [inlined]
  [5] try_process_progress
    @ c:\Users\accou\.vscode\extensions\julialang.language-julia-1.51.2\scripts\packages\VSCodeServer\src\progress.jl:51 [inlined]
  [6] handle_message(j::VSCodeServer.VSCodeLogger, level::Base.CoreLogging.LogLevel, message::String, _module::Module, group::Symbol, id::Symbol, file::String, line::Int64; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ VSCodeServer c:\Users\accou\.vscode\extensions\julialang.language-julia-1.51.2\scripts\packages\VSCodeServer\src\progress.jl:6
  [7] handle_message(j::VSCodeServer.VSCodeLogger, level::Base.CoreLogging.LogLevel, message::String, _module::Module, group::Symbol, id::Symbol, file::String, line::Int64)
    @ VSCodeServer c:\Users\accou\.vscode\extensions\julialang.language-julia-1.51.2\scripts\packages\VSCodeServer\src\progress.jl:4
  [8] #invokelatest#2
    @ .\essentials.jl:819 [inlined]
  [9] invokelatest
    @ .\essentials.jl:816 [inlined]
 [10] macro expansion
    @ .\logging.jl:330 [inlined]
 [11] macro expansion
    @ C:\Users\accou\.julia\packages\GPUCompiler\YO8Uj\src\utils.jl:64 [inlined]
 [12] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:8832
 [13] codegen
    @ C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:8723 [inlined]
 [14] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)     
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:9671
 [15] _thunk
    @ C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:9671 [inlined]
 [16] cached_compilation
    @ C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:9705 [inlined]
 [17] (::Enzyme.Compiler.var"#475#476"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{4, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:9768
 [18] JuliaContext(f::Enzyme.Compiler.var"#475#476"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{4, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler C:\Users\accou\.julia\packages\GPUCompiler\YO8Uj\src\driver.jl:47
 [19] #s292#474
    @ C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:9723 [inlined]
 [20] var"#s292#474"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [21] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:602
 [22] runtime_generic_augfwd(activity::Type{Val{(false, false, true, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true)}, RT::Val{NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3")), Tuple{Any, Any, Any}}}, f::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Dense{true, typeof(rbf), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(rbf), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(rbf), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, df::Nothing, primal_1::Vector{Float64}, shadow_1_1::Nothing, primal_2::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:15, Axis(weight = ViewAxis(1:10, ShapedAxis((5, 2), NamedTuple())), bias = ViewAxis(11:15, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(16:45, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_3 = ViewAxis(46:75, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_4 = ViewAxis(76:87, Axis(weight = ViewAxis(1:10, ShapedAxis((2, 5), NamedTuple())), bias = ViewAxis(11:12, ShapedAxis((2, 1), NamedTuple())))))}}}, shadow_2_1::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:15, Axis(weight = ViewAxis(1:10, ShapedAxis((5, 2), NamedTuple())), bias = ViewAxis(11:15, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(16:45, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_3 = ViewAxis(46:75, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_4 = ViewAxis(76:87, Axis(weight = ViewAxis(1:10, ShapedAxis((2, 5), NamedTuple())), bias = ViewAxis(11:12, ShapedAxis((2, 1), NamedTuple())))))}}}, primal_3::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}}, shadow_3_1::Nothing)
    @ Enzyme.Compiler C:\Users\accou\.julia\packages\Enzyme\0SYwj\src\compiler.jl:1361
wsmoses commented 1 year ago

Is this on main?

ChrisRackauckas commented 1 year ago

I didn't test main yet, one second.

ChrisRackauckas commented 1 year ago

On main the issue is just the BLAS fallbacks (Enzyme substitutes its own generic implementations for the BLAS calls so they can be differentiated, which is correct but can be slower):

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler C:\Users\accou\.julia\packages\GPUCompiler\cdxtJ\src\utils.jl:56
((nothing, nothing, nothing, nothing),)
ChrisRackauckas commented 1 year ago

With validation against ReverseDiff:

using Enzyme, Lux, Random, ComponentArrays, ReverseDiff
Enzyme.API.typeWarning!(false)
rng_ = Random.default_rng()

rbf(x) = exp.(-(x .^ 2))
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                Lux.Dense(5, 2))
p, st = Lux.setup(rng_, U)
const _st = st
p = ComponentVector{Float64}(p)

u = 5.0f0 * rand(2)
const p_ = [1.3, 0.9, 0.8, 1.8]

function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
    nothing
end

nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)

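# First Jacobian row: seeding the output shadow d_du with e₁ makes the
# reverse pass accumulate ∂f₁/∂u into d_u and ∂f₁/∂p into d_p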
du = zero(u)
d_du = zero(u); d_du[1] = 1
d_u = zero(u)
d_p = zero(p)
t = 0.0

Enzyme.autodiff(Reverse, nn_dynamics!, Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p), Const(t))
row1 = copy(d_u)
row1_p = copy(d_p)

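# Second row: reset all buffers and seed d_du with e₂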
du = zero(u)
d_du = zero(u); d_du[2] = 1
d_u = zero(u)
d_p = zero(p)
t = 0.0
Enzyme.autodiff(Reverse, nn_dynamics!, Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p), Const(t))

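# Stack the two pulled-back rows into the full Jacobians w.r.t. u and p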
enzyme_du = [row1 d_u]'
enzyme_dp = [row1_p d_p]'

function ude_dynamics(u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    [p_true[1] * u[1] + û[1], -p_true[4] * u[2] + û[2]]
end

rd_du, rd_dp = ReverseDiff.jacobian((u, p)) do u, p
    ude_dynamics(u, p, t, p_)
end

rd_du ≈ enzyme_du
rd_dp ≈ enzyme_dp

Works on main; it's just hitting the BLAS fallbacks.
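For anyone reducing this further, here is a minimal sketch of the same call pattern with no non-const globals at all, passing the network, its state, and the known parameters through a closure. (This is not from the issue; make_dynamics and nn_st are illustrative names, and the network is shrunk to two layers.)

using Enzyme, Lux, Random, ComponentArrays

rbf(x) = exp.(-(x .^ 2))
nn = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 2))
ps, nn_st = Lux.setup(Random.default_rng(), nn)
ps = ComponentVector{Float64}(ps)

# Capture the network, its state, and the known parameters locally so the
# dynamics function reads no global state at all
function make_dynamics(nn, st, p_true)
    function f!(du, u, p, t)
        û = nn(u, p, st)[1] # network prediction
        du[1] = p_true[1] * u[1] + û[1]
        du[2] = -p_true[4] * u[2] + û[2]
        nothing
    end
end

f! = make_dynamics(nn, nn_st, [1.3, 0.9, 0.8, 1.8])

u = rand(2)
du = zero(u)
d_du = zero(u); d_du[1] = 1 # seed e₁ to pull back the first output
d_u = zero(u)
d_p = zero(ps)
Enzyme.autodiff(Reverse, f!, Duplicated(du, d_du), Duplicated(u, d_u),
                Duplicated(ps, d_p), Const(0.0))

Enzyme treats the closure's captures as Const by default, so only u and p need shadow buffers; this avoids the type-unstable global reads without having to mark anything const.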