EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
443 stars 62 forks source link

Enzyme doesn't work for `AdvancedVI` Part III: `logpdf(MvNormal)` + `Base.Fix1` + `mean(f, x)` = error ??? #1560

Closed Red-Portal closed 3 months ago

Red-Portal commented 3 months ago

Hi, thanks for the quick patches! Looking at the errors, it seems we're not quite there yet, and reproducing the errors in an MWE is getting harder. But without further ado, here's a new issue:


using Enzyme, Distributions, LinearAlgebra, SimpleUnPack
# Switch on Enzyme's runtime-activity option before anything is differentiated.
Enzyme.API.runtimeActivity!(true)

# Minimal container for the two values that `logdensity` hands to `MvNormal`.
# Both fields are parametric, so any vector/matrix-like types can be stored.
struct TestNormal{M,S}
    μ::M  # passed as the first argument of `MvNormal` in `logdensity`
    Σ::S  # passed as the second argument of `MvNormal` in `logdensity`
end

# Evaluate the log-density of `θ` under the multivariate normal described by `model`.
# Builds `MvNormal(μ, Σ)` from the model's fields and returns its `logpdf` at `θ`.
function logdensity(model::TestNormal, θ)
    @unpack μ, Σ = model  # pull the `μ` and `Σ` fields out of `model`
    logpdf(MvNormal(μ, Σ), θ)
end

# Objective that the reproducer differentiates with Enzyme.
# `params` is broadcast-added to a freshly drawn 10×20 `randn` matrix, and the
# result is the mean of `logdensity(model, column)` over the columns of that matrix.
# NOTE: the row count is the literal 10, so `params` must broadcast against 10 rows.
function f(params, aux)
    @unpack model = aux  # `aux` is expected to expose a `model` property
    samples = randn(10, 20) .+ params  # no RNG is passed, so every call draws new samples
    mean(Base.Fix1(logdensity, model), eachcol(samples))  # Fix1 pins `model` as the first argument
end

# Reproducer entry point: runs Enzyme reverse-mode `autodiff` on `f` and returns `∇x`,
# the shadow buffer paired with `x` (where the gradient w.r.t. `x` is expected to land).
function main()
    d = 10
    m = zeros(d)
    C = Diagonal(ones(10))  # literal 10 here coincides with `d` above

    # `f` unpacks `model` from this named tuple.
    aux = (
        model = TestNormal(m, C),
    )

    params = randn(d)  # only its length is used below

    x = ones(length(params))    # input passed to `f` as `params`
    ∇x = zeros(length(params))  # zero-initialised shadow of `x`
    # `x` is wrapped as `Duplicated` with `∇x`, `aux` as `Const`; the value returned by
    # `autodiff` itself is discarded.
    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(x, ∇x), Enzyme.Const(aux))
    ∇x
end

This results in a segfault. Quite bizarrely, logpdf or mean alone don't trigger an error, but the combination does.

I suspect this is the cause of another non-segfault error caught in the full tests (this is not from the MWE above):

eterminism: Error During Test at /home/krkim/.julia/dev/AdvancedVI/test/inference/repgradelbo_locationscale.jl:61
  Got exception outside of a @test
  AssertionError: Base.allocatedinline(actualRetType) != Base.allocatedinline(literal_rt): actualRetType = Any, literal_rt = Float32, rettype = Active{Float32}
  Stacktrace:
    [1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{Nothing}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:3825
    [2] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool, Bool}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{Int64}, boxedArgs::Set{Int64})
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:3692
    [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:5832
    [4] codegen
      @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:5110 [inlined]
    [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6639
    [6] _thunk
      @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6639 [inlined]
    [7] cached_compilation
      @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6677 [inlined]
    [8] (::Enzyme.Compiler.var"#28587#28588"{DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6746
    [9] JuliaContext(f::Enzyme.Compiler.var"#28587#28588"{DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool, Bool}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
      @ GPUCompiler ~/.julia/packages/GPUCompiler/nWT2N/src/driver.jl:52
   [10] JuliaContext(f::Function)
      @ GPUCompiler ~/.julia/packages/GPUCompiler/nWT2N/src/driver.jl:42
   [11] #s2003#28586
      @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6697 [inlined]
   [12] var"#s2003#28586"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, ::Type, ::Type, ::Type, tt::Any, ::Type, ::Type, ::Type, ::Type, ::Type, ::Any)
      @ Enzyme.Compiler ./none:0
   [13] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
      @ Core ./boot.jl:602
   [14] runtime_generic_augfwd(activity::Type{Val{(false, true, true)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::typeof(AdvancedVI.estimate_energy_with_samples), df::Nothing, primal_1::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, shadow_1_1::Base.RefValue{TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}}, primal_2::Matrix{Float32}, shadow_2_1::Matrix{Float32})
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/rules/jitrules.jl:307
   [15] estimate_repgradelbo_ad_forward
      @ ~/.julia/dev/AdvancedVI/src/objectives/elbo/repgradelbo.jl:104 [inlined]
   [16] estimate_repgradelbo_ad_forward
      @ ~/.julia/dev/AdvancedVI/src/objectives/elbo/repgradelbo.jl:0 [inlined]
   [17] augmented_julia_estimate_repgradelbo_ad_forward_28206_inner_1wrap
      @ ~/.julia/dev/AdvancedVI/src/objectives/elbo/repgradelbo.jl:0
   [18] macro expansion
      @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6587 [inlined]
   [19] enzyme_call
      @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6188 [inlined]
   [20] AugmentedForwardThunk
      @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6076 [inlined]
   [21] autodiff(::ReverseMode{true, FFIABI, false}, ::Const{typeof(AdvancedVI.estimate_repgradelbo_ad_forward)}, ::Type{Active}, ::Duplicated{Vector{Float32}}, ::Const{@NamedTuple{rng::StableRNGs.LehmerRNG, obj::RepGradELBO{ClosedFormEntropy}, problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}, q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}}})
      @ Enzyme ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:253
   [22] autodiff
      @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
   [23] value_and_gradient!(::AutoEnzyme{Nothing}, f::Function, x::Vector{Float32}, aux::@NamedTuple{rng::StableRNGs.LehmerRNG, obj::RepGradELBO{ClosedFormEntropy}, problem::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}, q_stop::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}}, out::DiffResults.MutableDiffResult{1, Float32, Tuple{Vector{Float32}}})
      @ AdvancedVIEnzymeExt ~/.julia/dev/AdvancedVI/ext/AdvancedVIEnzymeExt.jl:43
   [24] estimate_gradient!(rng::StableRNGs.LehmerRNG, obj::RepGradELBO{ClosedFormEntropy}, adtype::AutoEnzyme{Nothing}, out::DiffResults.MutableDiffResult{1, Float32, Tuple{Vector{Float32}}}, prob::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, params::Vector{Float32}, restructure::Optimisers.Restructure{MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, @NamedTuple{location::Int64, scale::Int64}}, state::Nothing)
      @ AdvancedVI ~/.julia/dev/AdvancedVI/src/objectives/elbo/repgradelbo.jl:121
   [25] optimize(::StableRNGs.LehmerRNG, ::TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, ::RepGradELBO{ClosedFormEntropy}, ::MvLocationScale{LowerTriangular{Float32, Matrix{Float32}}, Normal{Float32}, Vector{Float32}, Float32}, ::Int64; adtype::AutoEnzyme{Nothing}, optimizer::Descent{Float32}, show_progress::Bool, state_init::@NamedTuple{}, callback::Nothing, prog::ProgressMeter.Progress)
      @ AdvancedVI ~/.julia/dev/AdvancedVI/src/optimize.jl:76
   [26] macro expansion
      @ ~/.julia/dev/AdvancedVI/test/inference/repgradelbo_locationscale.jl:63 [inlined]
   [27] macro expansion
      @ /usr/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [28] macro expansion
      @ ~/.julia/dev/AdvancedVI/test/inference/repgradelbo_locationscale.jl:62 [inlined]
   [29] macro expansion
      @ /usr/share/julia/stdlib/v1.10/Test/src/Test.jl:1669 [inlined]
   [30] macro expansion
      @ ~/.julia/dev/AdvancedVI/test/inference/repgradelbo_locationscale.jl:3 [inlined]
   [31] macro expansion
      @ /usr/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [32] top-level scope
      @ ~/.julia/dev/AdvancedVI/test/inference/repgradelbo_locationscale.jl:3
   [33] include(fname::String)
      @ Base.MainInclude ./client.jl:489
   [34] top-level scope
      @ ~/.julia/dev/AdvancedVI/test/runtests.jl:52
   [35] include(fname::String)
      @ Base.MainInclude ./client.jl:489
   [36] top-level scope
      @ none:6
   [37] eval
      @ ./boot.jl:385 [inlined]
   [38] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:291
   [39] _start()
      @ Base ./client.jl:552

but it's hard for me to tell. Let me know if they seem related.

wsmoses commented 3 months ago

Well, the fact that the MWEs are getting harder/less frequent means we're close to them all being done!

Will look at next

wsmoses commented 3 months ago

@Red-Portal fwiw the issue stems from estimate_energy_with_samples apparently not being type inferrable by julia.

something along the lines of

Core.Compiler.return_type(estimate_energy_with_samples, Tuple{TestNormal{Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, PDMat{Float32, Matrix{Float32}}}, Matrix{Float32}}) is spitting out a return of Any

Red-Portal commented 3 months ago

Oops sounds like it's a problem on my side. Let me take a look.

wsmoses commented 3 months ago

@Red-Portal so this passes for me on Enzyme#main (though I had to move the runtimeActivity to right after using Enzyme).

what is the error log?

wsmoses commented 3 months ago

@Red-Portal gentle ping here, what is your environment/how precisely did you run this (since I was unable to repro)

Red-Portal commented 3 months ago

@wsmoses Sorry for missing this! Yeah, I just tried it again, and it works. Thanks for the ping.