Closed Red-Portal closed 3 months ago
Can you post the error message?
On Sun, Aug 4, 2024 at 11:23 PM Kyurae Kim @.***> wrote:
Hi! Almost there. Most of the tests passed except for those involving Float32. Here is a MWE:
using Distributionsusing DiffResultsusing LinearAlgebrausing SimpleUnPack: @unpackusing Functorsusing Optimisersusing ADTypesusing Enzymeusing Random, StableRNGsusing FillArraysusing PDMats struct TestNormal{M,S} μ::M Σ::Send function logdensity(model::TestNormal, θ) @unpack μ, Σ = model logpdf(MvNormal(μ, Σ), θ)end function normal_fullrank(realtype::Type) n_dims = 5
σ0 = realtype(0.3) μ = Fill(realtype(5), n_dims) L = Matrix(σ0*I, n_dims, n_dims) Σ = L*L' |> Hermitian TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))end
function normal_meanfield(realtype::Type) n_dims = 5
σ0 = realtype(0.3) μ = Fill(realtype(5), n_dims) σ = Fill(σ0, n_dims) TestNormal(μ, Diagonal(σ.^2))end
struct MvLocationScale{ S, D <: ContinuousDistribution, L, E <: Real } <: ContinuousMultivariateDistribution location ::L scale ::S dist ::D scale_eps::Eend
@.** MvLocationScale (location, scale) function Distributions.rand( rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int ) where {S, D, L} @unpack location, scale, dist = q n_dims = length(location) scalerand(rng, dist, n_dims, num_samples) .+ locationend
# This specialization improves AD performance of the sampling path
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D} @unpack location, scale, dist = q n_dims = length(location) scale_diag = diag(scale) scale_diag.*rand(rng, dist, n_dims, num_samples) .+ locationend restructure_ad_forward( ::ADTypes.AutoEnzyme, restructure, params ) = restructure(params)::typeof(restructure.model)
function estimate_repgradelbo_ad_forward(params′, aux) @unpack rng, problem, adtype, restructure, q_stop = aux q = restructure_ad_forward(adtype, restructure, params′) zs = rand(rng, q, 10) mean(Base.Fix1(logdensity, problem), eachcol(zs))end function main() d = 5 adtype = AutoEnzyme()
seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) for T in [Float32, Float64], q in [ MvLocationScale(zeros(T, d), Diagonal(ones(T, d)), Normal{T}(zero(T), one(T)), T(1e-5)), MvLocationScale(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d)), Normal{T}(zero(T),one(T)), T(1e-5)) ], prob in [ normal_fullrank(T), normal_meanfield(T), ] params, re = Optimisers.destructure(q) q_stop = re(params) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) aux = ( rng = rng, adtype = adtype, problem = prob, restructure = re, q_stop = q_stop, ) @code_warntype estimate_repgradelbo_ad_forward(params, aux) println(typeof(estimate_repgradelbo_ad_forward(params, aux))) Enzyme.API.runtimeActivity!(true) ∇x = DiffResults.gradient(grad_buf) fill!(∇x, zero(eltype(∇x))) _, y = Enzyme.autodiff( Enzyme.ReverseWithPrimal, estimate_repgradelbo_ad_forward, Enzyme.Active, Enzyme.Duplicated(params, ∇x), Enzyme.Const(aux) ) endend
It's pretty much the same, except for the addition of Float32.
— Reply to this email directly, view it on GitHub https://github.com/EnzymeAD/Enzyme.jl/issues/1700, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXDWJCBB3DUCHMGHIDLZP3V47AVCNFSM6AAAAABL7ODNW6VHI2DSMVQWIX3LMV43ASLTON2WKOZSGQ2DONJRGI2TENQ . You are receiving this because you are subscribed to this thread.Message ID: @.***>
Oops sorry. The MWE doesn't seem to reproduce. Let me work on this more tomorrow.
@wsmoses Okay I finally got it. Here is the code:
using Distributions
using DiffResults
using LinearAlgebra
using SimpleUnPack: @unpack
using Functors
using Optimisers
using ADTypes
using Enzyme
using Random, StableRNGs
using FillArrays
using PDMats
# Minimal Gaussian test model: a mean vector μ and covariance Σ.
# M and S are left unconstrained so μ/Σ can be Fill/PDMat/Diagonal etc.
struct TestNormal{M,S}
μ::M
Σ::S
end
# Log density of the Gaussian test model evaluated at θ.
function logdensity(model::TestNormal, θ)
    return logpdf(MvNormal(model.μ, model.Σ), θ)
end
# Build a 5-dimensional full-rank Gaussian test problem in the requested
# floating-point type (covariance stored as a PDMat with its Cholesky factor).
function normal_fullrank(realtype::Type)
    dim = 5
    scale0 = realtype(0.3)
    mean_vec = Fill(realtype(5), dim)
    chol_factor = Matrix(scale0 * I, dim, dim)
    cov = Hermitian(chol_factor * chol_factor')
    return TestNormal(mean_vec, PDMat(cov, Cholesky(chol_factor, 'L', 0)))
end
# Build a 5-dimensional mean-field (diagonal-covariance) Gaussian test problem
# in the requested floating-point type.
function normal_meanfield(realtype::Type)
    dim = 5
    scale0 = realtype(0.3)
    mean_vec = Fill(realtype(5), dim)
    stddevs = Fill(scale0, dim)
    return TestNormal(mean_vec, Diagonal(stddevs .^ 2))
end
# Location-scale variational family: z = location + scale * ε with ε drawn
# iid per coordinate from the univariate base distribution `dist`.
# NOTE(review): the type-parameter order {S, D, L, E} (scale, dist, location,
# eps) deliberately differs from the field order below — method signatures
# elsewhere (e.g. `MvLocationScale{<:Diagonal, D, L}`) rely on this order.
struct MvLocationScale{
S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
location ::L   # mean-shift parameter
scale ::S      # linear scale (e.g. LowerTriangular or Diagonal)
dist ::D       # univariate base distribution of ε
scale_eps::E   # lower bound on the scale (not used in this MWE's paths)
end
# Dimension of the variational distribution — that of its location parameter.
function Base.length(q::MvLocationScale)
    return length(q.location)
end

# Size mirrors the location parameter's size.
function Base.size(q::MvLocationScale)
    return size(q.location)
end

# Sample element type is that of the univariate base distribution D.
function Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L}
    return eltype(D)
end
# Mark only location and scale as trainable leaves so Optimisers.destructure
# flattens exactly those two fields into the parameter vector.
Functors.@functor MvLocationScale (location, scale)
# Log density of the location-scale family via the change-of-variables formula:
# sum of base log densities at the standardized point, minus log|det(scale)|.
function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
    standardized = q.scale \ (z - q.location)
    base_terms = sum(Base.Fix1(logpdf, q.dist), standardized)
    return base_terms - logdet(q.scale)
end
# Draw a single sample: transform base-distribution noise by scale and shift
# by location (uses the global RNG).
function Distributions.rand(q::MvLocationScale)
    dim = length(q.location)
    return q.scale * rand(q.dist, dim) + q.location
end
# (stray "\" from email transcription removed)
# Draw `num_samples` samples as columns of a matrix: scale the base noise
# matrix, then broadcast-add the location.
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
) where {S, D, L}
    noise = rand(rng, q.dist, length(q.location), num_samples)
    return q.scale * noise .+ q.location
end
# This specialization improves AD performance of the sampling path
# (elementwise broadcast against the diagonal avoids a dense mat-mul).
function Distributions.rand(
    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
    diag_scale = diag(q.scale)
    noise = rand(rng, q.dist, length(q.location), num_samples)
    return diag_scale .* noise .+ q.location
end
# Rebuild the distribution from the flat parameter vector, with a concrete
# type assertion so the return type is inferrable inside Enzyme's AD pass.
function restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, params)
    return restructure(params)::typeof(restructure.model)
end
# Forward pass differentiated by Enzyme: restructure params into q, draw 10
# reparameterized samples, and average the target log density over them.
function estimate_repgradelbo_ad_forward(params′, aux)
    q = restructure_ad_forward(aux.adtype, aux.restructure, params′)
    samples = rand(aux.rng, q, 10)
    return mean(Base.Fix1(logdensity, aux.problem), eachcol(samples))
end
# Reproducer entry point: first sanity-check that the location-scale logpdf
# matches MvNormal, then differentiate the ELBO estimator with Enzyme for
# both Float32 and Float64.
function main()
    d = 5
    adtype = AutoEnzyme()
    seed = 0x38bef07cf9cc549d
    rng = StableRNG(seed)

    # Global Enzyme flag — loop-invariant, so set it once up front instead of
    # re-calling it on every iteration of the autodiff loop below.
    Enzyme.API.runtimeActivity!(true)

    # Sanity check: logpdf of the full-rank location-scale family agrees with
    # the equivalent MvNormal (should print `true` for each T).
    for T in [Float32, Float64]
        n_dims = 10
        μ = randn(T, n_dims)
        L = tril(I + ones(T, n_dims, n_dims)/2) |> LowerTriangular
        Σ = L*L'
        q = MvLocationScale(μ, L, Normal{T}(zero(T), one(T)), T(1e-5))
        q_true = MvNormal(μ, Σ)
        z = rand(q)
        println(logpdf(q, z) ≈ logpdf(q_true, z))
    end

    # Differentiate the ELBO estimator for both precisions.
    for T in [Float32, Float64]
        q = MvLocationScale(zeros(T, d), LowerTriangular(Matrix{T}(I, d, d)), Normal{T}(zero(T), one(T)), T(1e-5))
        prob = normal_fullrank(T)
        params, re = Optimisers.destructure(q)
        q_stop = re(params)
        grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
        aux = (
            rng = rng,
            adtype = adtype,
            problem = prob,
            restructure = re,
            q_stop = q_stop,
        )
        # Inspect inference of the primal before handing it to Enzyme.
        @code_warntype estimate_repgradelbo_ad_forward(params, aux)
        println(typeof(estimate_repgradelbo_ad_forward(params, aux)))
        ∇x = DiffResults.gradient(grad_buf)
        fill!(∇x, zero(eltype(∇x)))
        # ReverseWithPrimal returns (gradients, primal value).
        _, y = Enzyme.autodiff(
            Enzyme.ReverseWithPrimal,
            estimate_repgradelbo_ad_forward,
            Enzyme.Active,
            Enzyme.Duplicated(params, ∇x),
            Enzyme.Const(aux)
        )
    end
end
and the output is
true
true
MethodInstance for estimate_repgradelbo_ad_forward(::Vector{Float32}, ::@NamedTuple{rng::StableRNGs.LehmerRNG, adtype::AutoEnzyme{Nothing}, problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}, q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}})
from estimate_repgradelbo_ad_forward(params′, aux) @ Main REPL[28]:1
Arguments
#self#::Core.Const(estimate_repgradelbo_ad_forward)
params′::Vector{Float32}
aux::@NamedTuple{rng::StableRNGs.LehmerRNG, adtype::AutoEnzyme{Nothing}, problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}, q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}}
Locals
zs::Matrix{Float32}
q::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}
q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}
restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}
adtype::AutoEnzyme{Nothing}
problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}
rng::StableRNGs.LehmerRNG
Body::Float32
1 ─ (rng = Base.getproperty(aux, :rng))
│ (problem = Base.getproperty(aux, :problem))
│ (adtype = Base.getproperty(aux, :adtype))
│ (restructure = Base.getproperty(aux, :restructure))
│ (q_stop = Base.getproperty(aux, :q_stop))
│ (q = Main.restructure_ad_forward(adtype, restructure, params′))
│ (zs = Main.rand(rng, q, 10))
│ %8 = Base.Fix1::Core.Const(Base.Fix1)
│ %9 = Main.logdensity::Core.Const(logdensity)
│ %10 = (%8)(%9, problem)::Base.Fix1{typeof(logdensity), TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}}
│ %11 = Main.eachcol(zs)::Core.PartialStruct(ColumnSlices{Matrix{Float32}, Tuple{Base.OneTo{Int64}}, SubArray{Float32, 1, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}, Any[Matrix{Float32}, Core.Const((Colon(), 1)), Tuple{Base.OneTo{Int64}}])
│ %12 = Main.mean(%10, %11)::Float32
└── return %12
Float32
ERROR: AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
Stacktrace:
[1]
@ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:4097
[2] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:3957
[3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6081
[4] codegen
@ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:5389 [inlined]
[5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6891
[6] _thunk
@ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6891 [inlined]
[7] cached_compilation
@ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:6932 [inlined]
[8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:7005
[9] #s2043#28379
@ ~/.julia/packages/Enzyme/LhGXm/src/compiler.jl:7057 [inlined]
[10]
@ Enzyme.Compiler ./none:0
[11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[12] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…})
@ Enzyme ~/.julia/packages/Enzyme/LhGXm/src/Enzyme.jl:309
[13] autodiff
@ ~/.julia/packages/Enzyme/LhGXm/src/Enzyme.jl:326 [inlined]
[14] main()
@ Main ./REPL[29]:44
[15] top-level scope
@ REPL[30]:1
Some type information was truncated. Use `show(err)` to see complete types.
It seems that the line
println(logpdf(q, z) ≈ logpdf(q_true, z))
is causing the bug to happen, which is super weird because it feels like it shouldn't have anything to do with AD.
Should be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1703
Hi! Almost there. Most of the tests passed except for those involving
Float32
. Here is a MWE. It's pretty much the same, except for the addition of
Float32
.