Wrong gradient when using `bitcode_replacement!(false)` in neural ODE #1269

using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationNLopt,
    OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random
using Enzyme; Enzyme.Compiler.bitcode_replacement!(false)
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)

ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(rng, ann)
p = ComponentArray(ps)

θ, ax = getdata(p), getaxes(p)

function dxdt_(dx, x, p, t)
    ps = ComponentArray(p, ax)
    x1, x2 = x
    dx[1] = x[2]
    dx[2] = first(ann([t], ps, st))[1]^3
end

x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
prob = ODEProblem(dxdt_, x0, tspan, θ)
solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-10)

function predict_adjoint(θ)
    Array(solve(prob, Vern9(), p = θ, saveat = ts,
        sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())))
end

function loss_adjoint(θ)
    x = predict_adjoint(θ)
    ps = ComponentArray(θ, ax)
    mean(abs2, 4.0 .- x[1, :]) + 2mean(abs2, x[2, :]) +
    mean(abs2, [first(first(ann([t], ps, st))) for t in ts]) / 10
end

l = loss_adjoint(θ)

Running this prints the warning: ** On entry to SGEMV parameter number 6 had an illegal value **

Adapted from https://docs.sciml.ai/SciMLSensitivity/stable/examples/optimal_control/optimal_control/
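For context on the warning: in the reference BLAS argument order for SGEMV (TRANS, M, N, ALPHA, A, LDA, X, INCX, BETA, Y, INCY), parameter 6 is LDA, the leading dimension of A, which must be at least max(1, M). The sketch below mirrors the reference BLAS argument checks in plain Julia; `sgemv_info` is a hypothetical helper, not part of LinearAlgebra or Enzyme, and other BLAS implementations may organize their checks differently.

```julia
# Hypothetical helper mirroring the reference BLAS SGEMV argument checks for
# (TRANS, M, N, ALPHA, A, LDA, X, INCX, BETA, Y, INCY).
# Returns the 1-based index of the first illegal argument, or 0 if all pass.
function sgemv_info(trans::Char, m::Integer, n::Integer, lda::Integer,
                    incx::Integer, incy::Integer)
    uppercase(trans) in ('N', 'T', 'C') || return 1
    m < 0 && return 2
    n < 0 && return 3
    lda < max(1, m) && return 6   # A is m×n column-major, so LDA must be ≥ M
    incx == 0 && return 8
    incy == 0 && return 11
    return 0
end

sgemv_info('N', 30, 1, 1, 1, 1)    # 6: LDA = 1 is smaller than M = 30
sgemv_info('N', 30, 1, 30, 1, 1)   # 0: all arguments legal
```

Under that reading, the message suggests that some SGEMV call received a leading dimension smaller than the row count of its matrix.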

wsmoses commented 7 months ago

Can you post your Julia and Enzyme versions? If the bitcode flag needed to be passed, you were on an earlier version from before this was marked non-experimental, so it may have been fixed since.
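A minimal sketch for collecting this information from a script rather than the REPL, assuming Julia 1.9 or later (where `pkgversion` is available); `Pkg.status()` prints the same package list as `pkg> st`:

```julia
using InteractiveUtils, Pkg, Enzyme

versioninfo()                          # Julia version, platform, LLVM
println("Enzyme ", pkgversion(Enzyme)) # version of the loaded Enzyme package
Pkg.status()                           # same listing as `pkg> st`
```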

wsmoses commented 7 months ago

Can you also isolate this to just the Enzyme autodiff call, without the wrappers?

ArnoStrouwen commented 7 months ago

I think this is approximately what is going on inside SciMLSensitivity:

using Lux, ComponentArrays, OrdinaryDiffEq, SciMLSensitivity, Statistics, Random
using Enzyme; Enzyme.Compiler.bitcode_replacement!(false)
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)

ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(rng, ann)
p = ComponentArray(ps)

θ, ax = getdata(p), getaxes(p)

function dxdt_(dx, x, p, t)
    ps = ComponentArray(p, ax)
    x1, x2 = x
    dx[1] = x[2]
    dx[2] = first(ann([t], ps, st))[1]^3
end

x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))

dx = zero(x0)
function adfunc(out, u, _p, t)
    dxdt_(out, u, _p, t)
end

Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(dx, copy(x0)),
    Enzyme.Duplicated(copy(x0), zero(x0)), Enzyme.Duplicated(copy(θ), zero(θ)), Enzyme.Const(ts[1]))
(Enzyme) pkg> st
Status `~/SciML/SciMLSensitivity.jl/Enzyme/Project.toml`
  [b0b7db55] ComponentArrays v0.15.8
  [7da242da] Enzyme v0.11.14
  [b2108857] Lux v0.5.14
  [1dea7af3] OrdinaryDiffEq v6.70.1
  [1ed8b502] SciMLSensitivity v7.55.0
  [9a3f8284] Random
  [10745b16] Statistics v1.10.0

julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 24 × AMD Ryzen 9 5900X 12-Core Processor
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
  Threads: 1 on 24 virtual cores
  JULIA_PKG_DEVDIR = /home/arno/SciML/
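As a sketch of what the reduced `Enzyme.autodiff` call computes, assuming it mirrors the `EnzymeVJP` path as suggested above: the shadow of the output buffer carries a cotangent λ, and reverse mode accumulates λᵀ ∂f/∂u and λᵀ ∂f/∂p into the shadows of `u` and `p`. The toy right-hand side `rhs!` is hypothetical (not the issue's neural network), chosen so the vector-Jacobian products can be checked by hand:

```julia
using Enzyme

# Hypothetical in-place RHS of a toy linear system: du1 = p1*u2, du2 = p2*u1
function rhs!(du, u, p, t)
    du[1] = p[1] * u[2]
    du[2] = p[2] * u[1]
    return nothing
end

u = [1.0, 2.0]
p = [3.0, 4.0]
λ = [1.0, 10.0]            # cotangent seeded into the shadow of `du`

du = zeros(2)
du_shadow = copy(λ)
u_bar = zeros(2)           # receives λᵀ ∂f/∂u
p_bar = zeros(2)           # receives λᵀ ∂f/∂p

Enzyme.autodiff(Enzyme.Reverse, rhs!, Enzyme.Const,
    Enzyme.Duplicated(du, du_shadow), Enzyme.Duplicated(u, u_bar),
    Enzyme.Duplicated(p, p_bar), Enzyme.Const(0.0))

# By hand: u_bar = [λ2*p2, λ1*p1] = [40.0, 3.0]
#          p_bar = [λ1*u2, λ2*u1] = [2.0, 10.0]
```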
wsmoses commented 7 months ago

And for the sake of understanding, what is the expected result here that is not being computed correctly?

ArnoStrouwen commented 7 months ago

For the complete example in the original post of this issue, the problem is that the gradient differs depending on the bitcode flag.
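One way to tell which of the two gradients is wrong is to compare each against central finite differences. A minimal sketch, using a hypothetical scalar `loss` on Float64 parameters in place of `loss_adjoint` (finite differences in Float32 are noisy, so a tight tolerance only makes sense in Float64); `fd_gradient` is also a hypothetical helper:

```julia
using Enzyme

# Hypothetical helper: central finite-difference estimate of ∇f at θ (spot checks only).
function fd_gradient(f, θ::Vector{Float64}; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

# Hypothetical scalar loss standing in for loss_adjoint.
function loss(θ)
    s = 0.0
    for x in θ
        s += tanh(x)^2
    end
    return s + θ[1] * θ[2]
end

θ = [0.3, -0.7, 1.2]
g_enzyme = zero(θ)   # shadow; reverse mode accumulates ∂loss/∂θ here
Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active, Enzyme.Duplicated(θ, g_enzyme))

g_fd = fd_gradient(loss, θ)
isapprox(g_enzyme, g_fd; rtol = 1e-6)   # should hold if the AD gradient is right
```

With the Float32 model from the issue, the same comparison would need a much looser tolerance.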

For the pared-down example, I don't know what the issue could be; nothing seems immediately wrong to me in the Duplicated vectors, but I don't have much Enzyme experience.

I reduced the example so that it still prints: ** On entry to SGEMV parameter number 6 had an illegal value **

Perhaps I pared it down too much. Besides this autodiff call, SciMLSensitivity also contains a more complicated one, where dxdt_ gets a different wrapper (https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/adjoint_common.jl#L430-L455), and this wrapper is then Duplicated with make_zero: https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/adjoint_common.jl#L201C47-L201C100 https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/derivative_wrappers.jl#L696C63-L696C77

wsmoses commented 7 months ago

In order to debug this properly we'll need an example:

wsmoses commented 7 months ago

Reduced to:

using Enzyme; Enzyme.Compiler.bitcode_replacement!(false)

using LinearAlgebra
ps = zeros(Float32, 30, 1)
function adfunc(ps)
    out = Vector{Float32}(undef, 30)
    @inline LinearAlgebra.BLAS.gemv!('N', true, ps, [0.0f0], false, out)
    return out[1]
end

Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(deepcopy(ps), deepcopy(ps)))
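For a function like this, the correct gradient is known in closed form: if f(A) = (A x)₁, then ∂f/∂A[i, j] = x[j] when i = 1 and 0 otherwise. A sketch of checking Enzyme against that, assuming a 30×1 Float32 matrix as above but a nonzero `x` (with `x = [0.0f0]`, as in the reproducer, the true gradient is identically zero, which makes a wrong result harder to spot); `gemv_first` is a hypothetical name:

```julia
using Enzyme, LinearAlgebra

# Hypothetical reduced function: the only nontrivial operation is a Float32 gemv.
function gemv_first(A, x)
    y = zeros(Float32, size(A, 1))
    mul!(y, A, x)      # dense Float32 matrix-vector product, dispatched to BLAS gemv
    return y[1]
end

A = rand(Float32, 30, 1)
x = [2.0f0]
dA = zero(A)           # shadow of A; receives ∂f/∂A

Enzyme.autodiff(Enzyme.Reverse, gemv_first, Enzyme.Active,
    Enzyme.Duplicated(A, dA), Enzyme.Const(x))

# Closed form: ∂f/∂A[i, j] = x[j] if i == 1, else 0
expected = zeros(Float32, 30, 1)
expected[1, 1] = x[1]
dA ≈ expected          # should hold if the gemv derivative is correct
```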
wsmoses commented 7 months ago
julia> Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(deepcopy(ps), deepcopy(ps)))
after simplification :
@ZuseZ4 anything pop out?

wsmoses commented 7 months ago
using Enzyme; Enzyme.Compiler.bitcode_replacement!(false)

using LinearAlgebra
ps = zeros(Float32, 30, 1)
function adfunc(A)
    Y = Vector{Float32}(undef, 30)
    X = [1.0f0]

    trans = 'N'
    m,n = 30, 1
    lda = 1
    pX, sX = pointer(X), 1
    pY, sY = pointer(Y), 1
    pA = pointer(A)
    lda = 30
    alpha = 1.0f0
    beta = 0.0f0
    GC.@preserve A X Y ccall((:sgemv_64_, LinearAlgebra.BLAS.libblastrampoline), Cvoid,
        (Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float32},
         Ptr{Float32}, Ref{LinearAlgebra.BLAS.BlasInt}, Ptr{Float32}, Ref{LinearAlgebra.BLAS.BlasInt},
         Ref{Float32}, Ptr{Float32}, Ref{LinearAlgebra.BLAS.BlasInt}, Clong),
         trans, 30, 1, alpha,
         pA, lda, pX, sX,
         beta, pY, sY, 1)

    return @inbounds Y[1]
end

Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(deepcopy(ps), deepcopy(ps)))
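The raw `ccall` above ends up passing `lda = 30`, which is what a dense column-major 30×1 `Matrix` requires: its leading dimension equals its row count, i.e. `stride(A, 2)`. A quick forward-only sanity check of the same product through the `BLAS.gemv!` wrapper, without Enzyme:

```julia
using LinearAlgebra

A = rand(Float32, 30, 1)   # dense column-major 30×1 matrix, as in the reproducer
x = [1.0f0]
y = zeros(Float32, 30)

# gemv!(trans, alpha, A, x, beta, y) overwrites y with alpha*A*x + beta*y
LinearAlgebra.BLAS.gemv!('N', 1.0f0, A, x, 0.0f0, y)

y ≈ A * x       # true: same result as the generic product
stride(A, 2)    # 30: the leading dimension (LDA) of a dense Matrix is its row count
```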
wsmoses commented 7 months ago
julia> Enzyme.autodiff(Enzyme.Reverse, adfunc, Enzyme.Duplicated(deepcopy(ps), deepcopy(ps)))
after simplification :
wsmoses commented 7 months ago

Should be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1281

Please reopen if not.

ArnoStrouwen commented 7 months ago

It now produces the correct result with bitcode replacement both on and off. However, I am a bit surprised that the allocations are exactly the same in both versions. With bitcode replacement on:

julia> @btime Zygote.gradient(loss_adjoint,θ)
┌ Warning: Using fallback BLAS replacements for (["ssymv_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
  1.089 s (5063232 allocations: 674.09 MiB)
(Float32[-48.12045, 96.89185, 5.4492106, -136.30328, -277.6249, -2.9152653, 159.34677, -252.21376, -168.57451, 95.22521  …  28.876875, 58.53126, -94.83481, 123.85488, 202.57362, 72.3266, -231.3183, -164.42274, -63.517776, -324.779],)


With bitcode replacement off:

julia> @btime Zygote.gradient(loss_adjoint,θ)
  1.113 s (5063232 allocations: 674.09 MiB)
(Float32[-48.12045, 96.89185, 5.4492106, -136.30328, -277.6249, -2.9152653, 159.34677, -252.21376, -168.57451, 95.22521  …  28.876875, 58.53126, -94.83481, 123.85488, 202.57362, 72.3266, -231.3183, -164.42274, -63.517776, -324.779],)

If different BLAS code is used, would you not expect at least some difference in allocations?

ZuseZ4 commented 7 months ago

As far as I know, Julia's tooling measures allocations at a higher level, so such low-level allocations are not counted when the rules are used. I assume the same holds for the fallback.
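A small sketch of that distinction, assuming (not verified here) that the BLAS code obtains any workspace outside the Julia GC: `@allocated`, like the allocation counts printed by `@btime`, only tracks GC-managed memory, so memory obtained through C `malloc` is not reported. `c_side_alloc` and `julia_side_alloc` are hypothetical names:

```julia
# Hypothetical illustration: memory from C's malloc bypasses Julia's GC accounting.
function c_side_alloc(n)
    p = Libc.malloc(n)   # allocated outside the Julia GC
    Libc.free(p)
    return nothing
end

# Memory backing a Julia array is GC-managed, so it is counted.
julia_side_alloc(n) = zeros(UInt8, n)

# Call once so that compilation is not included in the measurements.
c_side_alloc(1024)
julia_side_alloc(1024)

@allocated c_side_alloc(1024)       # 0 bytes reported
@allocated julia_side_alloc(1024)   # at least 1024 bytes reported
```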

wsmoses commented 7 months ago

As of the latest release, bitcode replacement has no effect in reverse mode for dot/gemm/gemv/etc.; all of these use the TableGen rules rather than the fallback BLAS.