TuringLang / DynamicPPL.jl

Implementation of domain-specific language (DSL) for dynamic probabilistic programming
https://turinglang.org/DynamicPPL.jl/
MIT License
157 stars 26 forks source link

Type inference failure when compiled with custom AbstractInterpreter (e.g. GPUCompiler) #643

Closed wsmoses closed 1 week ago

wsmoses commented 3 months ago
julia> Core.Compiler.typeinf_ext_toplevel(interp, mi)
CodeInfo(
     @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:93 within `logdensity`
    ┌ @ Base.jl:37 within `getproperty`
1 ──│ %1  = Base.getfield(f, :varinfo)::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}
│   └
│   ┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/abstract_varinfo.jl:747 within `unflatten` @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:134 @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:137
│   │┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:116 within `VarInfo`
│   ││┌ @ Base.jl:37 within `getproperty`
│   │││ %2  = Base.getfield(%1, :metadata)::@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}
│   ││└
│   ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:158 within `newmetadata`
│   │││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl within `macro expansion`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %3  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %4  = Base.getfield(%3, :idcs)::Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}
│   │││││ %5  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %6  = Base.getfield(%5, :vns)::Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}
│   │││││ %7  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %8  = Base.getfield(%7, :ranges)::Vector{UnitRange{Int64}}
│   │││││ %9  = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %10 = Base.getfield(%9, :ranges)::Vector{UnitRange{Int64}}
│   ││││└
│   ││││ %11 = DynamicPPL.length::typeof(length)
│   ││││┌ @ reducedim.jl:1011 within `sum`
│   │││││┌ @ reducedim.jl:1011 within `#sum#829`
│   ││││││┌ @ reducedim.jl:1015 within `_sum`
│   │││││││┌ @ reducedim.jl:1015 within `#_sum#831`
│   ││││││││ %12 = Base.add_sum::typeof(Base.add_sum)
│   ││││││││┌ @ reducedim.jl:357 within `mapreduce`
│   │││││││││┌ @ reducedim.jl:357 within `#mapreduce#821`
│   ││││││││││┌ @ reducedim.jl:365 within `_mapreduce_dim`
│   │││││││││││ %13 = invoke Base._mapreduce(%11::typeof(length), %12::typeof(Base.add_sum), $(QuoteNode(IndexLinear()))::IndexLinear, %10::Vector{UnitRange{Int64}})::Int64
│   ││││└└└└└└└
│   ││││┌ @ int.jl:87 within `+`
│   │││││ %14 = Base.add_int(0, %13)::Int64
│   ││││└
│   ││││┌ @ range.jl:5 within `Colon`
│   │││││┌ @ range.jl:403 within `UnitRange`
│   ││││││┌ @ range.jl:414 within `unitrange_last`
│   │││││││┌ @ operators.jl:425 within `>=`
│   ││││││││┌ @ int.jl:514 within `<=`
│   │││││││││ %15 = Base.sle_int(1, %14)::Bool
│   │││││││└└
└───│││││││       goto #3 if not %15
2 ──│││││││       goto #4
3 ──│││││││       goto #4
    ││││││└
4 ┄─││││││ %19 = φ (#2 => %14, #3 => 0)::Int64
│   ││││││ %20 = %new(UnitRange{Int64}, 1, %19)::UnitRange{Int64}
└───││││││       goto #5
5 ──││││││       goto #6
    ││││└└
    ││││┌ @ array.jl:973 within `getindex`
6 ──│││││       goto #11 if not true
    │││││┌ @ abstractarray.jl:700 within `checkbounds`
7 ──││││││ %24 = Core.tuple(%20)::Tuple{UnitRange{Int64}}
│   ││││││ @ abstractarray.jl:702 within `checkbounds` @ abstractarray.jl:687
│   ││││││┌ @ abstractarray.jl:389 within `eachindex`
│   │││││││┌ @ abstractarray.jl:137 within `axes1`
│   ││││││││┌ @ abstractarray.jl:98 within `axes`
│   │││││││││┌ @ array.jl:191 within `size`
│   ││││││││││ %25 = Base.arraysize(θ, 1)::Int64
│   │││││││││└
│   │││││││││┌ @ tuple.jl:291 within `map`
│   ││││││││││┌ @ range.jl:469 within `oneto`
│   │││││││││││┌ @ range.jl:467 within `OneTo` @ range.jl:454
│   ││││││││││││┌ @ promotion.jl:532 within `max`
│   │││││││││││││┌ @ int.jl:83 within `<`
│   ││││││││││││││ %26 = Base.slt_int(%25, 0)::Bool
│   │││││││││││││└
│   │││││││││││││┌ @ essentials.jl:647 within `ifelse`
│   ││││││││││││││ %27 = Core.ifelse(%26, 0, %25)::Int64
│   ││││││└└└└└└└└
│   ││││││┌ @ abstractarray.jl:768 within `checkindex`
│   │││││││┌ @ range.jl:672 within `isempty`
│   ││││││││┌ @ operators.jl:378 within `>`
│   │││││││││┌ @ int.jl:83 within `<`
│   ││││││││││ %28 = Base.slt_int(%19, 1)::Bool
│   │││││││└└└
│   │││││││ @ abstractarray.jl:768 within `checkindex` @ abstractarray.jl:763
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %29 = Base.sub_int(1, 1)::Int64
│   │││││││└
│   │││││││┌ @ essentials.jl:524 within `unsigned`
│   ││││││││┌ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %30 = Base.bitcast(UInt64, %29)::UInt64
│   │││││││││ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %31 = Base.bitcast(UInt64, %27)::UInt64
│   │││││││└└
│   │││││││┌ @ int.jl:513 within `<`
│   ││││││││ %32 = Base.ult_int(%30, %31)::Bool
│   │││││││└
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %33 = Base.sub_int(%19, 1)::Int64
│   │││││││└
│   │││││││┌ @ essentials.jl:524 within `unsigned`
│   ││││││││┌ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %34 = Base.bitcast(UInt64, %33)::UInt64
│   │││││││││ @ essentials.jl:581 within `reinterpret`
│   │││││││││ %35 = Base.bitcast(UInt64, %27)::UInt64
│   │││││││└└
│   │││││││┌ @ int.jl:513 within `<`
│   ││││││││ %36 = Base.ult_int(%34, %35)::Bool
│   │││││││└
│   │││││││ @ abstractarray.jl:768 within `checkindex`
│   │││││││┌ @ bool.jl:38 within `&`
│   ││││││││ %37 = Base.and_int(%32, %36)::Bool
│   │││││││└
│   │││││││┌ @ bool.jl:39 within `|`
│   ││││││││ %38 = Base.or_int(%28, %37)::Bool
│   ││││││└└
│   ││││││ @ abstractarray.jl:702 within `checkbounds`
└───││││││       goto #9 if not %38
8 ──││││││       goto #10
9 ──││││││       invoke Base.throw_boundserror(θ::Vector{Float64}, %24::Tuple{UnitRange{Int64}})::Union{}
└───││││││       unreachable
10 ─││││││       nothing::Nothing
    │││││└
    │││││ @ array.jl:974 within `getindex`
    │││││┌ @ range.jl:761 within `length`
    ││││││┌ @ int.jl:86 within `-`
11 ┄│││││││ %44 = Base.sub_int(%19, 1)::Int64
│   ││││││└
│   ││││││┌ @ int.jl:87 within `+`
│   │││││││ %45 = Base.add_int(1, %44)::Int64
│   │││││└└
│   │││││ @ array.jl:975 within `getindex`
│   │││││┌ @ range.jl:706 within `axes`
│   ││││││┌ @ range.jl:761 within `length`
│   │││││││┌ @ int.jl:86 within `-`
│   ││││││││ %46 = Base.sub_int(%19, 1)::Int64
│   │││││││└
│   │││││││┌ @ int.jl:87 within `+`
│   ││││││││ %47 = Base.add_int(1, %46)::Int64
│   ││││││└└
│   ││││││┌ @ range.jl:469 within `oneto`
│   │││││││┌ @ range.jl:467 within `OneTo` @ range.jl:454
│   ││││││││┌ @ promotion.jl:532 within `max`
│   │││││││││┌ @ int.jl:83 within `<`
│   ││││││││││ %48 = Base.slt_int(%47, 0)::Bool
│   │││││││││└
│   │││││││││┌ @ essentials.jl:647 within `ifelse`
│   ││││││││││ %49 = Core.ifelse(%48, 0, %47)::Int64
│   │││││└└└└└
│   │││││┌ @ abstractarray.jl:831 within `similar` @ array.jl:420
│   ││││││┌ @ boot.jl:486 within `Array` @ boot.jl:477
│   │││││││ %50 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Float64}, svec(Any, Int64), 0, :(:ccall), Vector{Float64}, :(%49), :(%49)))::Vector{Float64}
│   │││││└└
│   │││││ @ array.jl:976 within `getindex`
│   │││││┌ @ operators.jl:378 within `>`
│   ││││││┌ @ int.jl:83 within `<`
│   │││││││ %51 = Base.slt_int(0, %45)::Bool
│   │││││└└
└───│││││       goto #13 if not %51
    │││││ @ array.jl:977 within `getindex`
    │││││┌ @ array.jl:368 within `copyto!`
12 ─││││││       invoke Base._copyto_impl!(%50::Vector{Float64}, 1::Int64, θ::Vector{Float64}, 1::Int64, %45::Int64)::Vector{Float64}
    │││││└
    │││││ @ array.jl:979 within `getindex`
13 ┄│││││       goto #14
    ││││└
    ││││┌ @ Base.jl:37 within `getproperty`
14 ─│││││ %55 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %56 = Base.getfield(%55, :dists)::Vector{IsoNormal}
│   │││││ %57 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %58 = Base.getfield(%57, :gids)::Vector{Set{DynamicPPL.Selector}}
│   │││││ %59 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %60 = Base.getfield(%59, :orders)::Vector{Int64}
│   │││││ %61 = Base.getfield(%2, :x)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   │││││ %62 = Base.getfield(%61, :flags)::Dict{String, BitVector}
│   ││││└
│   ││││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:47 within `Metadata`
│   │││││ %63 = %new(DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}, %4, %6, %8, %50, %56, %58, %60, %62)::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}
│   ││││└
│   ││││┌ @ boot.jl:622 within `NamedTuple`
│   │││││ %64 = %new(@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, %63)::@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}
│   ││││└
└───││││       goto #15
    ││└└
    ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo`
    ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:906 within `getlogp`
    │││┌ @ Base.jl:37 within `getproperty`
15 ─││││ %66 = Base.getfield(%1, :logp)::Base.RefValue{Float64}
│   │││└
│   │││┌ @ refvalue.jl:59 within `getindex`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %67 = Base.getfield(%66, :x)::Float64
│   ││└└└
│   ││┌ @ refvalue.jl:8 within `RefValue`
│   │││ %68 = %new(Base.RefValue{Float64}, %67)::Base.RefValue{Float64}
│   ││└
│   ││┌ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:923 within `get_num_produce`
│   │││┌ @ Base.jl:37 within `getproperty`
│   ││││ %69 = Base.getfield(%1, :num_produce)::Base.RefValue{Int64}
│   │││└
│   │││┌ @ refvalue.jl:59 within `getindex`
│   ││││┌ @ Base.jl:37 within `getproperty`
│   │││││ %70 = Base.getfield(%69, :x)::Int64
│   ││└└└
│   ││┌ @ refpointer.jl:137 within `Ref`
│   │││┌ @ refvalue.jl:10 within `RefValue` @ refvalue.jl:8
│   ││││ %71 = %new(Base.RefValue{Int64}, %70)::Base.RefValue{Int64}
│   ││└└
│   ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo` @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:100
│   ││ %72 = %new(TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, %64, %68, %71)::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}
│   ││ @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/varinfo.jl:117 within `VarInfo`
└───││       goto #16
16 ─││       goto #17
17 ─││       goto #18
18 ─││       goto #19
    └└
     @ /home/wmoses/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 within `logdensity`
19 ─       invoke DynamicPPL.evaluate!!($(QuoteNode(Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DefaultContext}(demo2, (var"##arg#225" = DynamicPPL.TypeWrap{Matrix{Float64}}(),), NamedTuple(), DefaultContext())))::Model{typeof(demo2), (Symbol("##arg#225"),), (), (), Tuple{DynamicPPL.TypeWrap{Matrix{Float64}}}, Tuple{}, DefaultContext}, %72::TypedVarInfo{@NamedTuple{x::DynamicPPL.Metadata{Dict{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}, Int64}, Vector{IsoNormal}, Vector{VarName{:x, IndexLens{Tuple{AbstractPPL.ConcretizedSlice{Int64, Base.OneTo{Int64}}, Int64}}}}, Vector{Float64}, Vector{Set{DynamicPPL.Selector}}}}, Float64}, $(QuoteNode(DefaultContext()))::DefaultContext)::Union{}
└───       unreachable
)

julia> Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype
Union{}
wsmoses commented 3 months ago

Some relevant code:

using Turing, Enzyme, LinearAlgebra, LogDensityProblems

using AbstractPPL
using DynamicPPL
using Accessors

using GPUCompiler
# Enable Enzyme's runtime-activity mode before the `Enzyme.autodiff` call further down.
# NOTE(review): in the Enzyme version used here this appears to be a global toggle rather
# than a per-call option — confirm against the installed Enzyme version.
Enzyme.API.runtimeActivity!(true);

# Minimal reproducer model: a `d`×`n` matrix `x` whose columns form a Gaussian random walk.
# The container type `TV` is passed as a type argument (default `Matrix{Float64}`) so the
# element type of `x` is determined by the argument types alone. If inference loses that
# element type, the `MvNormal` constructions below see an `Any` eltype.
@model function demo2(::Type{TV}=Matrix{Float64}) where {TV}
    d = 2  # length of each column
    n = 2  # number of columns
    x = TV(undef, d, n)  # uninitialised storage; every column is filled by a `~` below
    # First column: mean zero, identity covariance.
    x[:, 1] ~ MvNormal(zeros(d), I)
    for i = 2:n
        # Each later column is centred on the previous column.
        x[:, i] ~ MvNormal(x[:, i - 1], I)
    end
end

# Build the log-density wrapper for the model and a flat parameter vector to evaluate it at.
model = demo2()
ℓ = Turing.LogDensityFunction(model)
θ = ℓ.varinfo[:]  # current parameter values held by the VarInfo, as a flat vector

x = θ

# Primal evaluation (no AD) for comparison.
@show LogDensityProblems.logdensity(ℓ, x)

# Reverse-mode AD through Enzyme; this is the call that hits the inference failure in
# the report.
Enzyme.autodiff(ReverseWithPrimal, LogDensityProblems.logdensity, Active, Const(ℓ), Enzyme.Duplicated(x, zero(x)))

# Rebuild by hand the compiler job Enzyme creates for the call above, so inference can be
# run directly with GPUCompiler's AbstractInterpreter. The settings below are meant to
# mirror that call (reverse mode, primal returned, active return value).
# NOTE(review): these are Enzyme-internal APIs whose signatures may change between versions.
World = Base.get_world_counter()
FA = Const{typeof(LogDensityProblems.logdensity)}
A = Active
width = 1
Mode = Enzyme.API.DEM_ReverseModeCombined
ModifiedBetween = (false, false)
ReturnPrimal = true
ShadowInit = false
ABI = Enzyme.FFIABI
# Argument types of the call above, derived from the actual values rather than spelled
# out. The model type embeds what looks like a gensym'd argument name (e.g. `##arg#225`)
# that differs between sessions, so a hard-coded type would not match `ℓ`.
TT = Tuple{typeof(Const(ℓ)), typeof(Enzyme.Duplicated(x, zero(x)))}

mi = Enzyme.Compiler.fspec(eltype(FA), TT, World)

target = Enzyme.Compiler.EnzymeTarget()
params = Enzyme.Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, Enzyme.Compiler.remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, Enzyme.Compiler.UnknownTapeType, ABI)
tmp_job    = Enzyme.Compiler.CompilerJob(mi, Enzyme.Compiler.CompilerConfig(target, params; kernel=false), World)

interp = GPUCompiler.get_interpreter(tmp_job)

# `specialize_method` is defined in, and not exported from, `Core.Compiler`. It must
# therefore be qualified; unless it was explicitly imported, the bare name is undefined in
# `Main`.
spec = Core.Compiler.specialize_method(mi.def, mi.specTypes, mi.sparam_vals)
# Run inference for the same method instance with the custom interpreter. The report shows
# the inferred return type for this method instance collapsing to `Union{}` under it.
Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals)
torfjelde commented 3 months ago

So is this a Turing.jl issue or a GPUCompiler.jl issue, given that type inference works fine without GPUCompiler?

wsmoses commented 3 months ago

Not sure, but likely both. My understanding of Julia's semantics is that some things are explicitly unspecified, and the Julia compiler is free to choose (e.g. where inlining and type propagation happen). While it's possible GPUCompiler is forcing Julia to make different decisions, the fact that it can fail means that Julia is allowed to compile this code in a way that guarantees an error, and thus it is a bug in Turing.

On Wed, May 29, 2024, 2:54 PM Tor Erlend Fjelde @.***> wrote:

So is this a Turing.jl issue or a GPUCompiler.jl issue? Given that the type inference works nicely without GPUCompiler

— Reply to this email directly, view it on GitHub https://github.com/TuringLang/DynamicPPL.jl/issues/643, or unsubscribe https://github.com/notifications/unsubscribe-auth/AAJTUXDQVJUWRE24HNIDLHTZEXFZ7AVCNFSM6AAAAABINIQA52VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCMZXGM2DOMRVHA . You are receiving this because you authored the thread.Message ID: @.***>

torfjelde commented 3 months ago

Just to clarify a bit here: there's no "bug" per se in Turing.jl. The "bug" is just that there's no constructor for MvNormal with eltype Any, but arguably that's not a desirable thing to support. That is, this model is only expected to work when everything is type-stable.

But this constructor is only hit because GPUCompiler somehow causes an inference issue, leading to Any when every other approach correctly infers it as Float64.

wsmoses commented 3 months ago

Sure — I'm not sure which of the subpackages used by Turing the error comes from.

My best guess is that somewhere typeof(x) [aka inferred type] is used instead of Core.Typeof(x) [aka runtime type]; using the latter would correct the construction.

wsmoses commented 3 months ago

cc @maleadt @vchuravy for visibility

mhauru commented 1 month ago

Gathering links to Enzyme issues that came up when trying to minimise this:

yebai commented 3 weeks ago

@willtebbutt wrote in a slack discussion

I’ve found a fairly simple situation in which the results of inference differ depending on whether you use a Core.Compiler.NativeInterpreter() or one of the various custom AbstractInterpreters in use in the wild (e.g. Cthulhu.jl’s, Enzyme.jl’s, and Tapir.jl’s). In particular, the native interpreter successfully infers the return type of Base._mapreduce_dim(Base.Fix1(view, [5.0, 4.0]), vcat, Float64[], [1:1, 2:2], :) to be Vector{Float64}, while the other abstract interpreters infer Any.

using Cthulhu, Enzyme, Tapir

# Specify function + args.
# The first tuple element is the function being called; `code_typed_by_type` and
# `code_ircode_by_type` below take the full call signature (function type + argument
# types) as a tuple type.
fargs = (Base._mapreduce_dim, Base.Fix1(view, [5.0, 4.0]), vcat, Float64[], [1:1, 2:2], :)
tt = typeof(fargs)

# Construct the relevant interpreters.
native_interp = Core.Compiler.NativeInterpreter();
cthulhu_interp = Cthulhu.CthulhuInterpreter();
# NOTE(review): Enzyme-internal constructor; the positional arguments follow the Enzyme
# version used at the time of the report — verify against the installed version.
enzyme_interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(
    Enzyme.Compiler.GLOBAL_REV_CACHE,
    nothing,
    Base.get_world_counter(),
    Enzyme.API.DEM_ReverseModeCombined,
);
tapir_interp = Tapir.TapirInterpreter();

# Both of these correctly infer the return type, Vector{Float64}.
Base.code_typed_by_type(tt; optimize=true, interp=native_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=native_interp)

# Inference fails (the return type is inferred as Any).
Base.code_typed_by_type(tt; optimize=true, interp=cthulhu_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=cthulhu_interp)

# Inference fails (the return type is inferred as Any).
Base.code_typed_by_type(tt; optimize=true, interp=enzyme_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=enzyme_interp)

# Inference fails (the return type is inferred as Any).
Base.code_typed_by_type(tt; optimize=true, interp=tapir_interp)
Base.code_ircode_by_type(tt; optimize_until=nothing, interp=tapir_interp)

@wsmoses pointed out the above compiler bug might be related to this issue.

yebai commented 1 week ago

Not a DynamicPPL/Turing issue; closing in favour of https://github.com/JuliaLang/julia/issues/55638