ChrisRackauckas closed this 1 year ago
Is this on main?
I didn't test main yet, one second.
On main the issue is just BLAS fallbacks:
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler C:\Users\accou\.julia\packages\GPUCompiler\cdxtJ\src\utils.jl:56
((nothing, nothing, nothing, nothing),)
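The printed `((nothing, nothing, nothing, nothing),)` is the reverse-mode return value. It has one entry per argument passed to `autodiff` (`du`, `u`, `p`, `t`). Because every argument is `Duplicated` or `Const`, none of them returns a derivative; the gradients are accumulated into the shadow arrays instead. Below is a minimal sketch contrasting this with an `Active` scalar argument. It assumes a recent Enzyme.jl release in which `autodiff(Reverse, ...)` returns a tuple whose first element holds the derivatives of the `Active` arguments. `square` and `scale!` are hypothetical toy functions:

```julia
using Enzyme

# Hypothetical toy functions for illustration only.
square(x) = x^2

function scale!(y, x)
    y .= 2 .* x          # y = 2x elementwise
    return nothing
end

# Active scalar argument: its derivative is returned directly in the tuple.
res_active = Enzyme.autodiff(Reverse, square, Active, Active(3.0))

# Duplicated arrays: derivatives are accumulated into the shadow `dx`,
# so the returned tuple only contains `nothing` entries.
x = [1.0, 2.0]; dx = zeros(2)
y = zeros(2);   dy = [1.0, 0.0]   # seed: pull back the first output only
res_dup = Enzyme.autodiff(Reverse, scale!, Duplicated(y, dy), Duplicated(x, dx))
# dx now holds the first row of the Jacobian of scale!
```

This is the same situation as the issue's call: all four arguments are shadowed or constant, so every entry of the returned tuple is `nothing`.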
With validation:
using Enzyme, Lux, Random, ComponentArrays, ReverseDiff
Enzyme.API.typeWarning!(false)  # silence Enzyme type-analysis warnings
rng_ = Random.default_rng()
rbf(x) = exp.(-(x .^ 2))  # Gaussian radial-basis activation
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
    Lux.Dense(5, 2))
p, st = Lux.setup(rng_, U)
const _st = st
p = ComponentVector{Float64}(p)
u = 5.0f0 * rand(2)
const p_ = [1.3, 0.9, 0.8, 1.8]
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
    nothing
end
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Reverse pass 1: seed the output adjoint with e_1 to pull back row 1 of the Jacobian
du = zero(u)
d_du = zero(u); d_du[1] = 1
d_u = zero(u);
d_p = zero(p)
t = 0.0
Enzyme.autodiff(Reverse, nn_dynamics!, Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p), Const(t))
row1 = copy(d_u);
row1_p = copy(d_p);
# Reverse pass 2: fresh zeroed shadows (Enzyme accumulates into them), seed with e_2 for row 2
du = zero(u)
d_du = zero(u); d_du[2] = 1;
d_u = zero(u);
d_p = zero(p)
t = 0.0
Enzyme.autodiff(Reverse, nn_dynamics!, Duplicated(du, d_du), Duplicated(u, d_u), Duplicated(p, d_p), Const(t))
# Stack the two rows into the Jacobians of du with respect to u and p
enzyme_du = [row1 d_u]'
enzyme_dp = [row1_p d_p]'
# Out-of-place version of the same dynamics, used as the ReverseDiff reference
function ude_dynamics(u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    [p_true[1] * u[1] + û[1], -p_true[4] * u[2] + û[2]]
end
rd_du, rd_dp = ReverseDiff.jacobian((u, p)) do u, p
    ude_dynamics(u, p, t, p_)
end
rd_du ≈ enzyme_du
rd_dp ≈ enzyme_dp
Works on main, but it's just hitting the BLAS fallbacks.
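The two copy-pasted reverse passes in the validation script follow one pattern. Seeding the output shadow `d_du` with the unit vector e_i pulls back row i of the Jacobian into `d_u` and `d_p`. Below is a minimal sketch of that pattern as a loop. It assumes the same `autodiff(Reverse, f!, Duplicated, Duplicated, Duplicated, Const)` call used above. `jacobian_rows` and the linear `toy_dynamics!` are hypothetical helpers for illustration, not part of Enzyme or Lux:

```julia
using Enzyme

# Hypothetical helper: Jacobians of an in-place f!(du, u, p, t) with respect
# to u and p, built one row per reverse pass (one pass per output).
function jacobian_rows(f!, u, p, t)
    n = length(u)
    Ju = zeros(n, length(u))
    Jp = zeros(n, length(p))
    for i in 1:n
        du = zero(u)
        d_du = zero(u); d_du[i] = 1   # seed output i with a unit adjoint
        d_u = zero(u)                  # shadows must start at zero because
        d_p = zero(p)                  # Enzyme accumulates into them
        Enzyme.autodiff(Reverse, f!, Duplicated(du, d_du),
                        Duplicated(u, d_u), Duplicated(p, d_p), Const(t))
        Ju[i, :] .= d_u                # row i: ∂du[i]/∂u
        Jp[i, :] .= d_p                # row i: ∂du[i]/∂p
    end
    return Ju, Jp
end

# Hypothetical linear test system with a Jacobian that is easy to check by hand.
function toy_dynamics!(du, u, p, t)
    du[1] = p[1] * u[1]
    du[2] = -p[2] * u[2]
    return nothing
end

u = [2.0, 3.0]
p = [1.3, 0.9]
Ju, Jp = jacobian_rows(toy_dynamics!, u, p, 0.0)
```

Called as `jacobian_rows(nn_dynamics!, u, p, t)` with the issue's variables, this should reproduce `enzyme_du` and `enzyme_dp` as plain matrices, one Enzyme call per output. That is the usual trade-off for reverse mode when there are few outputs and many parameters.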
Boiling https://docs.sciml.ai/Overview/stable/showcase/missing_physics/ down to the Enzyme part, start with:
Boils down to:
Which gives: