EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
455 stars 64 forks source link

Enzyme doesn't work for `AdvancedVI` Part VII: Type instability with FP32 #1700

Closed Red-Portal closed 3 months ago

Red-Portal commented 3 months ago

Hi! Almost there. Most of the tests pass; only those involving Float32 fail. Here is an MWE:

using Distributions
using DiffResults
using LinearAlgebra
using SimpleUnPack: @unpack
using Functors
using Optimisers
using ADTypes
using Enzyme
using Random, StableRNGs
using FillArrays
using PDMats

"""
    TestNormal(μ, Σ)

Target model used by the reproducer: a multivariate normal with mean `μ` and
covariance `Σ`. Its log density is evaluated by `logdensity`.
"""
struct TestNormal{Tμ,TΣ}
    μ::Tμ   # mean
    Σ::TΣ   # covariance
end

"""
    logdensity(model::TestNormal, θ)

Return the log density of `θ` under `MvNormal(model.μ, model.Σ)`.
"""
logdensity(model::TestNormal, θ) = logpdf(MvNormal(model.μ, model.Σ), θ)

"""
    normal_fullrank(realtype::Type; n_dims::Integer = 5)

Build a `TestNormal` target with element type `realtype` and dimension `n_dims`
(default `5`, the value previously hard-coded). The mean is filled with `5`, and the
covariance is `Σ = L*L'` with `L = 0.3 * I`. `Σ` is stored as a `PDMat` together with
its Cholesky factor `L`.
"""
function normal_fullrank(realtype::Type; n_dims::Integer = 5)
    σ0 = realtype(0.3)
    μ  = Fill(realtype(5), n_dims)
    L  = Matrix(σ0*I, n_dims, n_dims)
    # Hermitian wrapper: the covariance is symmetric by construction.
    Σ  = Hermitian(L*L')

    # `L` is lower triangular (diagonal here) and `Σ = L*L'`, so it is supplied
    # directly as the lower Cholesky factor (info code 0 = success).
    return TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))
end

"""
    normal_meanfield(realtype::Type; n_dims::Integer = 5)

Build a `TestNormal` target with element type `realtype` and dimension `n_dims`
(default `5`, the value previously hard-coded). The mean is filled with `5`, and the
covariance is diagonal with every variance equal to `0.3^2`.
"""
function normal_meanfield(realtype::Type; n_dims::Integer = 5)
    σ0 = realtype(0.3)
    μ  = Fill(realtype(5), n_dims)
    σ  = Fill(σ0, n_dims)

    return TestNormal(μ, Diagonal(σ.^2))
end

"""
    MvLocationScale(location, scale, dist, scale_eps)

Multivariate location-scale distribution. A draw is `scale * ε + location`, where the
entries of `ε` are sampled independently from the univariate base distribution `dist`.
`scale_eps` is stored, but the code in this reproducer does not read it.
"""
struct MvLocationScale{S, D <: ContinuousDistribution, L, E <: Real} <: ContinuousMultivariateDistribution
    location ::L   # shift added to every draw
    scale    ::S   # scale matrix (e.g. `Diagonal` or `LowerTriangular`)
    dist     ::D   # univariate base distribution of each entry of ε
    scale_eps::E
end

# Expose only `location` and `scale` as children, so `Optimisers.destructure`
# flattens just these two fields into the trainable parameter vector.
Functors.@functor MvLocationScale (location, scale)

"""
    rand(rng::AbstractRNG, q::MvLocationScale, num_samples::Int)

Draw `num_samples` samples from `q` and return them as the columns of a matrix,
`scale * ε .+ location`. Here `ε` is a `length(location) × num_samples` matrix of
independent draws from `dist`.
"""
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    ε = rand(rng, q.dist, length(q.location), num_samples)
    return q.scale * ε .+ q.location
end

# Specialization for diagonal scales: scaling by the diagonal vector with broadcasting,
# instead of a full matrix product, improves AD performance of the sampling path.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
    σ = diag(q.scale)
    ε = rand(rng, q.dist, length(q.location), num_samples)
    return σ .* ε .+ q.location
end

"""
    restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)

Rebuild the model from the flat parameter vector `params`. The result is type-asserted
to match the template model stored in `restructure`, presumably so that inference (and
Enzyme) sees a concrete return type.
"""
function restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)
    return restructure(params)::typeof(restructure.model)
end

"""
    estimate_repgradelbo_ad_forward(params′, aux)

Rebuild the variational distribution `q` from `params′`, draw 10 samples from it with
`aux.rng`, and return the sample mean of `logdensity(aux.problem, z)` over the sample
columns `z`.

`aux` must provide `rng`, `problem`, `adtype` and `restructure`. Any other fields
(for example `q_stop`) are ignored.
"""
function estimate_repgradelbo_ad_forward(params′, aux)
    # `q_stop` is not needed here, so it is not unpacked.
    @unpack rng, problem, adtype, restructure = aux
    q  = restructure_ad_forward(adtype, restructure, params′)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

"""
    main()

Reproducer for EnzymeAD/Enzyme.jl#1700. The loop covers every combination of:

- element type: `Float32`, `Float64`;
- variational family: diagonal and lower-triangular `MvLocationScale`;
- target: `normal_fullrank`, `normal_meanfield`.

For each combination, print the type-inference report and primal return type of
`estimate_repgradelbo_ad_forward`. Then differentiate it with respect to the flat
parameter vector using Enzyme in reverse mode.
"""
function main()
    d      = 5
    adtype = AutoEnzyme()

    seed = 0x38bef07cf9cc549d
    rng  = StableRNG(seed)

    # Global Enzyme setting that does not depend on the loop variables, so it is set
    # once rather than on every iteration.
    Enzyme.API.runtimeActivity!(true)

    for T in [Float32, Float64],
        q in [
            MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)), T(1e-5)),
            MvLocationScale(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d)), Normal{T}(zero(T), one(T)), T(1e-5)),
        ],
        prob in [
            normal_fullrank(T),
            normal_meanfield(T),
        ]

        # Flatten the trainable fields of `q` into `params`; `re` rebuilds a model from them.
        params, re = Optimisers.destructure(q)
        q_stop     = re(params)
        grad_buf   = DiffResults.DiffResult(zero(eltype(params)), similar(params))

        aux = (
            rng         = rng,
            adtype      = adtype,
            problem     = prob,
            restructure = re,
            q_stop      = q_stop,
        )

        # Type-stability diagnostics for the function being differentiated.
        @code_warntype estimate_repgradelbo_ad_forward(params, aux)
        println(typeof(estimate_repgradelbo_ad_forward(params, aux)))

        # Zero the shadow buffer first: Enzyme accumulates gradients into it.
        ∇x = DiffResults.gradient(grad_buf)
        fill!(∇x, zero(eltype(∇x)))
        _, y = Enzyme.autodiff(
            Enzyme.ReverseWithPrimal,
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, ∇x),
            Enzyme.Const(aux)
        )
        # Store the primal value so `grad_buf` holds both the objective and its gradient.
        grad_buf = DiffResults.value!(grad_buf, y)
    end
end

It's essentially the same as before, except that Float32 has been added.

wsmoses commented 3 months ago

Can you post the error message?

On Sun, Aug 4, 2024 at 11:23 PM Kyurae Kim @.***> wrote:

Hi! Almost there. Most of the tests passed except for those involving Float32. Here is a MWE:

using Distributions
using DiffResults
using LinearAlgebra
using SimpleUnPack: @unpack
using Functors
using Optimisers
using ADTypes
using Enzyme
using Random, StableRNGs
using FillArrays
using PDMats

struct TestNormal{M,S}
    μ::M
    Σ::S
end

function logdensity(model::TestNormal, θ)
    @unpack μ, Σ = model
    logpdf(MvNormal(μ, Σ), θ)
end

function normal_fullrank(realtype::Type)
    n_dims = 5

    σ0 = realtype(0.3)
    μ  = Fill(realtype(5), n_dims)
    L  = Matrix(σ0*I, n_dims, n_dims)
    Σ  = L*L' |> Hermitian

    TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))
end

function normal_meanfield(realtype::Type)
    n_dims = 5

    σ0 = realtype(0.3)
    μ  = Fill(realtype(5), n_dims)
    σ  = Fill(σ0, n_dims)

    TestNormal(μ, Diagonal(σ.^2))
end

struct MvLocationScale{
    S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
    location ::L
    scale    ::S
    dist     ::D
    scale_eps::E
end

Functors.@functor MvLocationScale (location, scale)

function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    @unpack location, scale, dist = q
    n_dims = length(location)
    scale*rand(rng, dist, n_dims, num_samples) .+ location
end

# This specialization improves AD performance of the sampling path
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
    @unpack location, scale, dist = q
    n_dims     = length(location)
    scale_diag = diag(scale)
    scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
end

restructure_ad_forward(
    ::ADTypes.AutoEnzyme, restructure, params
) = restructure(params)::typeof(restructure.model)

function estimate_repgradelbo_ad_forward(params′, aux)
    @unpack rng, problem, adtype, restructure, q_stop = aux
    q  = restructure_ad_forward(adtype, restructure, params′)
    zs = rand(rng, q, 10)
    mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

function main()
    d      = 5
    adtype = AutoEnzyme()

seed = (0x38bef07cf9cc549d)
rng  = StableRNG(seed)

for T in [Float32, Float64],
    q in [
    MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)), T(1e-5)),
    MvLocationScale(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d)), Normal{T}(zero(T),one(T)), T(1e-5))
    ], prob in [
        normal_fullrank(T),
        normal_meanfield(T),
    ]
    params, re = Optimisers.destructure(q)
    q_stop     = re(params)
    grad_buf   = DiffResults.DiffResult(zero(eltype(params)), similar(params))

    aux = (
        rng         = rng,
        adtype      = adtype,
        problem     = prob,
        restructure = re,
        q_stop      = q_stop,
    )

    @code_warntype estimate_repgradelbo_ad_forward(params, aux)
    println(typeof(estimate_repgradelbo_ad_forward(params, aux)))

    Enzyme.API.runtimeActivity!(true)
    ∇x = DiffResults.gradient(grad_buf)
    fill!(∇x, zero(eltype(∇x)))
    _, y = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal,
        estimate_repgradelbo_ad_forward,
        Enzyme.Active,
        Enzyme.Duplicated(params, ∇x),
        Enzyme.Const(aux)
    )
    end
end

It's pretty much the same, except for the addition of Float32.

— Reply to this email directly, view it on GitHub https://github.com/EnzymeAD/Enzyme.jl/issues/1700, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXDWJCBB3DUCHMGHIDLZP3V47AVCNFSM6AAAAABL7ODNW6VHI2DSMVQWIX3LMV43ASLTON2WKOZSGQ2DONJRGI2TENQ . You are receiving this because you are subscribed to this thread.Message ID: @.***>

Red-Portal commented 3 months ago

Oops, sorry. The MWE doesn't seem to reproduce the error. Let me work on this more tomorrow.

Red-Portal commented 3 months ago

@wsmoses Okay I finally got it. Here is the code:

using Distributions
using DiffResults
using LinearAlgebra
using SimpleUnPack: @unpack
using Functors
using Optimisers
using ADTypes
using Enzyme
using Random, StableRNGs
using FillArrays
using PDMats

"""
    TestNormal(μ, Σ)

Multivariate normal target with mean `μ` and covariance `Σ`. Its log density is
computed by `logdensity`.
"""
struct TestNormal{Tμ,TΣ}
    μ::Tμ   # mean
    Σ::TΣ   # covariance
end

"""
    logdensity(model::TestNormal, θ)

Log density of `θ` under the normal target `MvNormal(model.μ, model.Σ)`.
"""
logdensity(model::TestNormal, θ) = logpdf(MvNormal(model.μ, model.Σ), θ)

"""
    normal_fullrank(realtype::Type; n_dims::Integer = 5)

Build a `TestNormal` target with element type `realtype` and dimension `n_dims`
(default `5`, the value previously hard-coded). The mean is filled with `5`, and the
covariance is `Σ = L*L'` with `L = 0.3 * I`. `Σ` is stored as a `PDMat` together with
its Cholesky factor `L`.
"""
function normal_fullrank(realtype::Type; n_dims::Integer = 5)
    σ0 = realtype(0.3)
    μ  = Fill(realtype(5), n_dims)
    L  = Matrix(σ0*I, n_dims, n_dims)
    # Hermitian wrapper: the covariance is symmetric by construction.
    Σ  = Hermitian(L*L')

    # `L` is lower triangular (diagonal here) and `Σ = L*L'`, so it is supplied
    # directly as the lower Cholesky factor (info code 0 = success).
    return TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))
end

"""
    normal_meanfield(realtype::Type; n_dims::Integer = 5)

Build a `TestNormal` target with element type `realtype` and dimension `n_dims`
(default `5`, the value previously hard-coded). The mean is filled with `5`, and the
covariance is diagonal with every variance equal to `0.3^2`.
"""
function normal_meanfield(realtype::Type; n_dims::Integer = 5)
    σ0 = realtype(0.3)
    μ  = Fill(realtype(5), n_dims)
    σ  = Fill(σ0, n_dims)

    return TestNormal(μ, Diagonal(σ.^2))
end

"""
    MvLocationScale(location, scale, dist, scale_eps)

Multivariate location-scale distribution. Samples are `scale * ε + location`, where
every entry of `ε` is an independent draw from the univariate base distribution
`dist`. `scale_eps` is stored, but the code shown here never reads it.
"""
struct MvLocationScale{S, D <: ContinuousDistribution, L, E <: Real} <: ContinuousMultivariateDistribution
    location ::L   # shift vector
    scale    ::S   # scale matrix (e.g. `LowerTriangular`)
    dist     ::D   # univariate base distribution
    scale_eps::E
end

# Dimension of `q`: the length of its location vector.
function Base.length(q::MvLocationScale)
    return length(q.location)
end

# Size of `q`, taken from its location vector.
function Base.size(q::MvLocationScale)
    return size(q.location)
end

# Element type of `q`, delegated to the type `D` of its base distribution.
Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)

# Only `location` and `scale` are children, so they are the only fields flattened
# into the parameter vector by `Optimisers.destructure`.
Functors.@functor MvLocationScale (location, scale)

"""
    logpdf(q::MvLocationScale, z::AbstractVector{<:Real})

Log density of `z` under `q`. Because `z = scale * ε + location`, where the entries of
`ε` are i.i.d. draws from `dist`, the change of variables gives

    log q(z) = Σᵢ logpdf(dist, εᵢ) - log|det(scale)|,   ε = scale⁻¹ (z - location).

The Jacobian term uses `logabsdet`. `logdet` would be wrong here: for a scale with a
negative determinant (e.g. a negative diagonal entry) it may throw a `DomainError`,
whereas the density only depends on the absolute value of the determinant.
"""
function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
    @unpack location, scale, dist = q
    ε = scale \ (z - location)
    return sum(Base.Fix1(logpdf, dist), ε) - first(logabsdet(scale))
end

"""
    rand(q::MvLocationScale)

Draw a single sample `scale * ε + location` from `q` using the default RNG, where `ε`
has `length(location)` independent entries drawn from `dist`.
"""
function Distributions.rand(q::MvLocationScale)
    @unpack location, scale, dist = q
    n_dims = length(location)
    return scale*rand(dist, n_dims) + location
end

"""
    rand(rng::AbstractRNG, q::MvLocationScale, num_samples::Int)

Return a `length(location) × num_samples` matrix whose columns are independent samples
`scale * ε .+ location` from `q`. The entries of `ε` are drawn from `dist` using `rng`.
"""
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
)  where {S, D, L}
    ε = rand(rng, q.dist, length(q.location), num_samples)
    return q.scale * ε .+ q.location
end

# Diagonal-scale specialization: multiplying by the diagonal vector with broadcasting,
# rather than by the full matrix, improves AD performance of the sampling path.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
    σ = diag(q.scale)
    ε = rand(rng, q.dist, length(q.location), num_samples)
    return σ .* ε .+ q.location
end

"""
    restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)

Rebuild the model from the flat vector `params`. The result is asserted to have the
type of the template model held by `restructure`, presumably to give Enzyme a concrete
return type.
"""
function restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)
    return restructure(params)::typeof(restructure.model)
end

"""
    estimate_repgradelbo_ad_forward(params′, aux)

Rebuild the variational distribution `q` from `params′`, draw 10 samples from it with
`aux.rng`, and return the sample mean of `logdensity(aux.problem, z)` over the sample
columns `z`.

`aux` must provide `rng`, `problem`, `adtype` and `restructure`. Any other fields
(for example `q_stop`) are ignored.
"""
function estimate_repgradelbo_ad_forward(params′, aux)
    # `q_stop` is not needed here, so it is not unpacked.
    @unpack rng, problem, adtype, restructure = aux
    q  = restructure_ad_forward(adtype, restructure, params′)
    zs = rand(rng, q, 10)
    return mean(Base.Fix1(logdensity, problem), eachcol(zs))
end

"""
    main()

Reproducer for EnzymeAD/Enzyme.jl#1700, in two phases:

1. A sanity check that `logpdf` of a lower-triangular `MvLocationScale` matches the
   equivalent `MvNormal`.
2. For `Float32` and `Float64`, reverse-mode Enzyme differentiation of
   `estimate_repgradelbo_ad_forward` with respect to the flat parameter vector.
"""
function main()
    d      = 5
    adtype = AutoEnzyme()

    seed = 0x38bef07cf9cc549d
    rng  = StableRNG(seed)

    # Sanity check of `logpdf`. Per the issue thread, this check (which involves no AD)
    # is needed to trigger the Enzyme failure in the loop below, so it is kept.
    for T in [Float32, Float64]
        n_dims = 10
        μ = randn(T, n_dims)
        L = tril(I + ones(T, n_dims, n_dims)/2) |> LowerTriangular
        # Wrap in Hermitian, as `normal_fullrank` does. In general `L*L'` is only
        # symmetric up to rounding, and the Cholesky factorization inside `MvNormal`
        # rejects non-Hermitian input.
        Σ = Hermitian(L*L')

        q = MvLocationScale(μ, L, Normal{T}(zero(T), one(T)), T(1e-5))
        q_true = MvNormal(μ, Σ)

        z = rand(q)
        println(logpdf(q, z) ≈ logpdf(q_true, z))
    end

    # Global Enzyme setting that does not depend on the loop variables, so it is set
    # once before the AD loop rather than on every iteration.
    Enzyme.API.runtimeActivity!(true)

    for T in [Float32, Float64]
        q = MvLocationScale(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d)), Normal{T}(zero(T), one(T)), T(1e-5))
        prob = normal_fullrank(T)

        # Flatten the trainable fields of `q` into `params`; `re` rebuilds a model from them.
        params, re = Optimisers.destructure(q)
        q_stop     = re(params)
        grad_buf   = DiffResults.DiffResult(zero(eltype(params)), similar(params))

        aux = (
            rng         = rng,
            adtype      = adtype,
            problem     = prob,
            restructure = re,
            q_stop      = q_stop,
        )

        # Type-stability diagnostics for the function being differentiated.
        @code_warntype estimate_repgradelbo_ad_forward(params, aux)
        println(typeof(estimate_repgradelbo_ad_forward(params, aux)))

        # Zero the shadow buffer first: Enzyme accumulates gradients into it.
        ∇x = DiffResults.gradient(grad_buf)
        fill!(∇x, zero(eltype(∇x)))
        _, y = Enzyme.autodiff(
            Enzyme.ReverseWithPrimal,
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, ∇x),
            Enzyme.Const(aux)
        )
        # Store the primal value so `grad_buf` holds both the objective and its gradient.
        grad_buf = DiffResults.value!(grad_buf, y)
    end
end

and the output is

true
true
MethodInstance for estimate_repgradelbo_ad_forward(::Vector{Float32}, ::@NamedTuple{rng::StableRNGs.LehmerRNG, adtype::AutoEnzyme{Nothing}, problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}, q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}})
  from estimate_repgradelbo_ad_forward(params′, aux) @ Main REPL[28]:1
Arguments
  #self#::Core.Const(estimate_repgradelbo_ad_forward)
  params′::Vector{Float32}
  aux::@NamedTuple{rng::StableRNGs.LehmerRNG, adtype::AutoEnzyme{Nothing}, problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}, q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}}
Locals
  zs::Matrix{Float32}
  q::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}
  q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}
  restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}
  adtype::AutoEnzyme{Nothing}
  problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}
  rng::StableRNGs.LehmerRNG
Body::Float32
1 ─       (rng = Base.getproperty(aux, :rng))
│         (problem = Base.getproperty(aux, :problem))
│         (adtype = Base.getproperty(aux, :adtype))
│         (restructure = Base.getproperty(aux, :restructure))
│         (q_stop = Base.getproperty(aux, :q_stop))
│         (q = Main.restructure_ad_forward(adtype, restructure, params′))
│         (zs = Main.rand(rng, q, 10))
│   %8  = Base.Fix1::Core.Const(Base.Fix1)
│   %9  = Main.logdensity::Core.Const(logdensity)
│   %10 = (%8)(%9, problem)::Base.Fix1{typeof(logdensity), TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}}
│   %11 = Main.eachcol(zs)::Core.PartialStruct(ColumnSlices{Matrix{Float32}, Tuple{Base.OneTo{Int64}}, SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}, Any[Matrix{Float32}, Core.Const((Colon(), 1)), Tuple{Base.OneTo{Int64}}])
│   %12 = Main.mean(%10, %11)::Float32
└──       return %12

Float32
ERROR: AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
Stacktrace:
  [1] 
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:4097
  [2] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:3957
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6081
  [4] codegen
    @ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:5389 [inlined]
  [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6891
  [6] _thunk
    @ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6891 [inlined]
  [7] cached_compilation
    @ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6932 [inlined]
  [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:7005
  [9] #s2043#28379
    @ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:7057 [inlined]
 [10] 
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…})
    @ Enzyme ~/.julia/packages/Enzyme/LhGXm/src/Enzyme.jl:309
 [13] autodiff
    @ ~/.julia/packages/Enzyme/LhGXm/src/Enzyme.jl:326 [inlined]
 [14] main()
    @ Main ./REPL[29]:44
 [15] top-level scope
    @ REPL[30]:1
Some type information was truncated. Use `show(err)` to see complete types.

It seems that the line

println(logpdf(q, z) ≈ logpdf(q_true, z))

is what triggers the bug. This is very strange, because that line does not seem to involve AD at all.

wsmoses commented 3 months ago

Should be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1703