EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
422 stars 58 forks source link

Enzyme issue with component arrays #1451

Closed ChrisRackauckas closed 1 month ago

ChrisRackauckas commented 1 month ago

Found when updating https://github.com/SciML/SciMLSensitivity.jl/pull/1052


using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationNLopt,
      OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random

# Reproduction setup: build a small neural network, flatten its parameters
# into a plain vector, and keep the ComponentArray axes in a `const` global.
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)  # Float32 integration span

# 1 -> 32 -> 32 -> 1 dense network with tanh activations.
ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(rng, ann)
p = ComponentArray(ps)

# Split the ComponentArray into its raw data vector and its axes; the axes
# go into a `const` global so the ODE RHS can re-wrap the flat vector.
θ, _ax = getdata(p), getaxes(p)
const ax = _ax

# In-place ODE right-hand side. Re-wraps the flat parameter vector `p` with
# the globally cached axes, then drives the second state with the cube of
# the network output evaluated at time `t`.
# (Fixed: removed the unused destructuring `x1, x2 = x` — the state is
# indexed directly below, so the two locals were dead.)
function dxdt_(dx, x, p, t)
    ps = ComponentArray(p, ax)  # rebuild structured params from flat vector
    dx[1] = x[2]
    # `first(ann([t], ps, st))` is the network output array; `[1]` takes its
    # single scalar entry.
    dx[2] = first(ann([t], ps, st))[1]^3
end
x0 = [-4.0f0, 0.0f0]                       # initial state (Float32)
ts = Float32.(collect(0.0:0.01:tspan[2]))  # save points used by the loss
prob = ODEProblem(dxdt_, x0, tspan, θ)
# Smoke-test solve before optimizing. NOTE(review): tolerances are Float64
# literals while the state is Float32 — presumably intentional in the repro.
solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-10)

# Solve the ODE with parameters `θ` and return the states saved at `ts`
# as a dense matrix (states × time points).
function predict_adjoint(θ)
    sol = solve(prob, Vern9(), p = θ, saveat = ts)
    return Array(sol)
end
# Loss: drive state 1 toward 4, penalize state 2, and lightly regularize
# the network output u(t) over the save points (same three terms, summed
# in the same order as the original).
function loss_adjoint(θ)
    traj = predict_adjoint(θ)
    params = ComponentArray(θ, ax)
    u = [first(first(ann([t], params, st))) for t in ts]
    tracking = mean(abs2, 4f0 .- traj[1, :])
    damping = 2mean(abs2, traj[2, :])
    control = mean(abs2, u) / 10
    return tracking + damping + control
end

# Sanity check: evaluate the loss once at the initial parameters.
l = loss_adjoint(θ)
# Optimization callback: print the current loss and, when `doplot` is set,
# plot the trajectory together with the learned control u(t). Returning
# `false` tells the optimizer not to stop.
cb = function (state, l; doplot = true)
    println(l)

    ps = ComponentArray(state.u, ax)

    if doplot
        sol = solve(remake(prob, p = state.u), Tsit5(), saveat = 0.01)
        plt = plot(sol, ylim = (-6, 6), lw = 3)
        u = [first(first(ann([t], ps, st))) for t in ts]
        plot!(plt, ts, u, label = "u(t)", lw = 3)
        display(plt)
    end

    return false
end

# Setup and run the optimization

loss1 = loss_adjoint(θ)             # baseline loss before optimizing
adtype = Optimization.AutoZygote()  # outer-level AD backend
# Optimization.jl calls the objective as (params, hyperparams); the second
# argument is unused here.
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

optprob = Optimization.OptimizationProblem(optf, θ)
# Adam for 100 iterations — this solve is where the reported error surfaces.
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.01), callback = cb, maxiters = 100)

Throws this error

ChrisRackauckas commented 1 month ago

Note that by default it only emits a warning, because the call is wrapped in a try/catch:

https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/concrete_solve.jl#L23-L27

You can force the error by changing the solve call to `solve(prob, Vern9(), p = θ, saveat = ts, sensealg = GaussAdjoint(autojacvec = EnzymeVJP()))`.

avik-pal commented 1 month ago

A more minimal version:

using LuxLib
using ComponentArrays, Random, Enzyme

# Parameters as a plain NamedTuple — this variant differentiates fine below.
ps = (; w=rand(5, 5), b=rand(5))

xtest = rand(5, 1)
dx = zeros(size(xtest))  # shadow buffer for the Duplicated input

# Apply the fused dense layer (presumably ps.w * x .+ ps.b with identity
# activation — see LuxLib's `fused_dense_bias_activation`) and reduce the
# result to a scalar so it can serve as a reverse-mode objective.
function test_function(x, ps, st)
    out = LuxLib.fused_dense_bias_activation(identity, ps.w, x, ps.b)
    return sum(out)
end

st = NamedTuple()  # empty state placeholder

# Primal call: verify the function itself runs before differentiating.
test_function(xtest, ps, st)

# Reverse-mode through the NamedTuple parameters: this succeeds.
@time autodiff(Reverse, test_function, Active, Duplicated(xtest, dx), Const(ps), Const(st))

ps_ca = ComponentArray(ps)

# Identical call with the parameters wrapped in a ComponentArray: this is
# the failing case the issue reports.
@time autodiff(
    Reverse, test_function, Active, Duplicated(xtest, dx), Const(ps_ca), Const(st))
wsmoses commented 1 month ago

@avik-pal that minimization was perfect.

I know what this is and I will start working on a fix

wsmoses commented 1 month ago

@ChrisRackauckas the latest jll bump I just pushed should remedy this

wsmoses commented 1 month ago

However, the broadcast now requires runtime activity here (with runtime activity enabled, everything runs). I'll see if we can improve alias analysis to fix that, but at least you should now be unblocked.

wsmoses commented 1 month ago

For the broadcast mixed activity perf: https://fwd.gymni.ch/k03gKb

wsmoses commented 1 month ago

Should now be fixed on main with latest jll bump, please reopen if it persists.