EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
422 stars 57 forks source link

Enzyme doesn't work for AdvancedVI #1548

Closed Red-Portal closed 1 week ago

Red-Portal commented 1 week ago

The following snippet, which captures the key element of AdvancedVI, fails when differentiated with Enzyme:

using Enzyme, Functors, Optimisers, Distributions, LinearAlgebra, Random

"""
    MvLocationScale{S, D, L}

Multivariate location-scale distribution used in this repro. Sampling (see the
`Distributions.rand` method in this file) draws an `n_dims × num_samples` array
from the univariate base distribution `dist`, left-multiplies it by `scale`, and
shifts every column by `location`.

# Fields
- `location::L`: shift vector; `length(location)` is used as the dimension.
- `scale::S`: scale matrix (a `LowerTriangular` in the repro driver).
- `dist::D`: univariate base distribution (e.g. `Normal()`).
"""
struct MvLocationScale{
    S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
end

# `@functor` takes the type as its first argument, followed by the tuple of
# fields to expose as children. Only `location` and `scale` are listed, so
# `Optimisers.destructure` flattens just those into the parameter vector and
# `dist` is carried through `restructure` unchanged.
@functor MvLocationScale (location, scale)

"""
    rand(rng, q::MvLocationScale, num_samples)

Draw `num_samples` samples from `q` as the columns of a matrix: sample an
`length(q.location) × num_samples` array from `q.dist`, left-multiply it by
`q.scale`, then broadcast-add `q.location` to every column.
"""
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    base_draws = rand(rng, q.dist, length(q.location), num_samples)
    return q.scale * base_draws .+ q.location
end

"""
    f(params, aux)

Repro objective: rebuild the distribution from the flat parameter vector
`params` with `aux.restructure`, draw `aux.n_samples` samples from it using
`aux.rng`, and return the sum of all sampled entries.
"""
function f(params, aux)
    rebuilt = aux.restructure(params)
    return sum(rand(aux.rng, rebuilt, aux.n_samples))
end

"""
    main()

Driver for the first repro. It builds a 10-dimensional `MvLocationScale` with
zero location, an identity `LowerTriangular` scale and a `Normal()` base, then
flattens it with `Optimisers.destructure` and prints the primal value of `f`.
Finally it differentiates `f` with Enzyme in `ReverseWithPrimal` mode at a
point of all ones.

Returns the shadow vector passed as `Duplicated(x, ∇x)`. Enzyme accumulates the
gradient with respect to the flat parameters into it.
"""
function main()
    dim = 10
    chol = LowerTriangular(Matrix{Float64}(I, dim, dim))
    flat, rebuild = Optimisers.destructure(MvLocationScale(zeros(dim), chol, Normal()))
    aux = (rng = Random.default_rng(), restructure = rebuild, n_samples = 4)

    println(f(flat, aux))

    point = ones(length(flat))
    grad = zeros(length(flat))
    # Runtime activity is switched on globally before differentiating.
    Enzyme.API.runtimeActivity!(true)
    Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, f, Enzyme.Active,
        Enzyme.Duplicated(point, grad), Enzyme.Const(aux),
    )
    return grad
end

This yields an error message (quite a mouthful!) that can be found on this pastebin.

Skipping the `rand` call, so that only `restructure` is invoked, as follows:

"""
    f(params, aux)

Reduced repro objective. It rebuilds the distribution from the flat parameter
vector `params` via `aux.restructure` and returns the sum of its `location`
field, bypassing `rand` entirely.
"""
function f(params, aux)
    # Only `restructure` is read here. `aux` may also carry `rng`/`n_samples`
    # for the sampling variant of `f`, but this variant does not need them.
    (; restructure) = aux
    q = restructure(params)
    return sum(q.location)
end

results in a different error that can be found here.

wsmoses commented 1 week ago

It looks like you hit our newly added (and accidentally released) direct `trmm` support. https://github.com/EnzymeAD/Enzyme/pull/1933 should fix the issue once the JLL is bumped.

That fixes the first issue. Incidentally, once that fix is in, the first snippet no longer hits the second error.

The second issue should be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1550

wsmoses commented 1 week ago

Merged on main, so closing. I'll make a release once things run.

Red-Portal commented 1 week ago

Thanks! I'll give it a go after the new release and re-open if there are issues still looming.

wsmoses commented 1 week ago

Started the registration; it'll probably be out in about an hour. You may need to use a Julia Pkg server set to eager mode if your server has already checked the package registry today.

Red-Portal commented 1 week ago

Got a new case, so this issue will have to be re-opened:

using Enzyme, Functors, Optimisers, Distributions, LinearAlgebra, Random, StatsBase

"""
    MvLocationScale{S, D, L}

Multivariate location-scale distribution used in the second repro. The methods
in this file treat it as follows:
- `rand` draws from `dist`, left-multiplies by `scale`, and adds `location`.
- `entropy` returns `length(location) * entropy(dist) + logdet(scale)`.
"""
struct MvLocationScale{
    S, D <: ContinuousDistribution, L
} <: ContinuousMultivariateDistribution
    location ::L  # shift vector; `length(location)` is used as the dimension
    scale    ::S  # matrix that left-multiplies the base draws in `rand`
    dist     ::D  # univariate base distribution sampled for every coordinate
end

# Expose only `location` and `scale` as Functors children, so
# `Optimisers.destructure` flattens just those into the trainable parameter
# vector while `dist` is carried through `restructure` unchanged.
@functor MvLocationScale (location, scale)

"""
    entropy(q::MvLocationScale)

Return `length(q.location) * entropy(q.dist) + logdet(q.scale)`. This is the
entropy of `scale * z + location` when each coordinate of `z` is an
independent draw from `dist`.

NOTE(review): `logdet` assumes `q.scale` has a positive determinant (e.g. a
triangular matrix with positive diagonal); confirm this for other scales.
"""
function StatsBase.entropy(q::MvLocationScale)
    dim = length(q.location)
    # `convert` is necessary because `entropy` is not type stable upstream
    per_coord = convert(eltype(q.location), entropy(q.dist))
    return dim * per_coord + logdet(q.scale)
end

"""
    rand(rng, q::MvLocationScale, num_samples)

Return a matrix whose `num_samples` columns are samples from `q`. Each column
is `q.scale * z .+ q.location`, where `z` holds `length(q.location)` draws
from `q.dist`.
"""
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    loc, A, base = q.location, q.scale, q.dist
    Z = rand(rng, base, length(loc), num_samples)
    return A * Z .+ loc
end

"""
    reparam_with_entropy(rng, q, n_samples) -> (samples, entropy)

Draw `n_samples` samples from `q` using `rng`, then return them together with
`entropy(q)` as a 2-tuple.
"""
function reparam_with_entropy(rng::Random.AbstractRNG, q, n_samples::Int)
    draws = rand(rng, q, n_samples)
    ent = entropy(q)
    return (draws, ent)
end

"""
    f(params, aux)

Repro objective for the second case. It rebuilds `q` from the flat parameter
vector `params` with `aux.restructure` and draws `aux.n_samples` samples
(using `aux.rng`) together with the entropy of `q`. It returns
`sum(samples) + entropy`.
"""
function f(params, aux)
    (; n_samples, rng, restructure) = aux
    q = restructure(params)
    # Pass the configured sample count instead of a hard-coded literal, so that
    # `aux.n_samples` actually controls how many samples are drawn.
    samples, ent = reparam_with_entropy(rng, q, n_samples)
    return sum(samples) + ent
end

"""
    main()

Driver for the second repro. It builds a 10-dimensional `MvLocationScale`
(zero location, identity `LowerTriangular` scale, `Normal()` base), flattens it
with `Optimisers.destructure`, and prints `f` at the flat parameters. It then
differentiates `f` with Enzyme (`ReverseWithPrimal`) at a point of all ones.

Returns the shadow vector `∇x` that Enzyme accumulates the gradient into.
"""
function main()
    d = 10
    q = MvLocationScale(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)), Normal())
    θ, re = Optimisers.destructure(q)
    aux = (; rng = Random.default_rng(), restructure = re, n_samples = 4)
    println(f(θ, aux))

    x, ∇x = ones(length(θ)), zeros(length(θ))
    # Runtime activity is switched on globally before differentiating.
    Enzyme.API.runtimeActivity!(true)
    Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        f,
        Enzyme.Active,
        Enzyme.Duplicated(x, ∇x),
        Enzyme.Const(aux),
    )
    return ∇x
end

This spits out:

ERROR: Enzyme execution failed.
Enzyme could not find shadow for value

 Inverted pointers: 
available inversion for { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %0 of { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %"'"
available inversion for   %6 = getelementptr inbounds { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] }, { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %0, i64 0, i32 1, !dbg !55 of   %"'ipg" = getelementptr inbounds { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] }, { {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* %"'", i64 0, i32 1, !dbg !68

cannot find shadow for   %5 = call fastcc nonnull {} addrspace(10)* @julia_rand_8346({ {} addrspace(10)*, [1 x {} addrspace(10)*], [2 x double] } addrspace(11)* noalias nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %0, i64 signext %1) #51, !dbg !54, !noalias !42

Caused by:
Stacktrace:
 [1] reparam_with_entropy
   @ ./REPL[65]:6
 [2] reparam_with_entropy
   @ ./REPL[65]:0

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:1612
  [2] reparam_with_entropy
    @ ./REPL[65]:0 [inlined]
  [3] augmented_julia_reparam_with_entropy_8321_inner_1wrap
    @ ./REPL[65]:0
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6587 [inlined]
  [5] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::MixedDuplicated{…}, ::Const{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6188
  [6] (::Enzyme.Compiler.AugmentedForwardThunk{…})(::Const{…}, ::Const{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6076
  [7] runtime_generic_augfwd(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(reparam_with_entropy), df::Nothing, primal_1::TaskLocalRNG, shadow_1_1::Nothing, primal_2::MvLocationScale{…}, shadow_2_1::Base.RefValue{…}, primal_3::Int64, shadow_3_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RsJic/src/rules/jitrules.jl:311
  [8] f
    @ ./REPL[66]:4 [inlined]
  [9] f
    @ ./REPL[66]:0 [inlined]
 [10] augmented_julia_f_8166_inner_1wrap
    @ ./REPL[66]:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6587 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6188 [inlined]
 [13] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/RsJic/src/compiler.jl:6076 [inlined]
 [14] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/RsJic/src/Enzyme.jl:253
 [15] autodiff
    @ ~/.julia/packages/Enzyme/RsJic/src/Enzyme.jl:321 [inlined]
 [16] main()
    @ Main ./REPL[67]:18
 [17] top-level scope
    @ REPL[68]:1
Some type information was truncated. Use `show(err)` to see complete types.
wsmoses commented 1 week ago

Can you open that as a separate issue so we can track it?

Will work on that next