TuringLang / DynamicPPL.jl

Implementation of domain-specific language (DSL) for dynamic probabilistic programming
https://turinglang.org/DynamicPPL.jl/
MIT License

Issue with ReverseDiff and undef on Julia v1.7 #612

Open torfjelde opened 3 months ago

torfjelde commented 3 months ago

Issue

Julia v1.7's implementation of copy! is not compatible with undef entries, which causes quite a few issues for us when we're using BangBang.jl. Ref: https://github.com/JuliaFolds2/BangBang.jl/issues/21 https://github.com/JuliaFolds2/BangBang.jl/pull/22 and their references.
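For reference, a minimal sketch of the underlying failure (an assumed reproduction based on the stack trace further down, not taken from the linked issues): on Julia v1.7, copy! loops over the source with getindex, so undefined entries of a non-isbits element type raise an UndefRefError.

# Minimal sketch (assumption): Julia v1.7's `copy!` reads every source element,
# so an array with #undef entries of a non-isbits eltype throws.
src = Vector{BigFloat}(undef, 2)  # entries are #undef because BigFloat is not isbits
dst = Vector{BigFloat}(undef, 2)
copy!(dst, src)                   # UndefRefError on Julia v1.7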

After https://github.com/JuliaFolds2/BangBang.jl/pull/22, a lot of these issues were addressed, but it still doesn't quite cover the case where we're working with AD backends that use types for tracing, e.g. ReverseDiff.jl.

In particular, something like

using DynamicPPL, Distributions

@model function demo(::Type{TV}=Vector{Float64}) where {TV}
    # Allocate with undef entries, then fill via indexed `~` statements.
    x = TV(undef, 1)
    x[1] ~ Normal()
end

will fail when used with ReverseDiff.jl, because it hits the NoBang version and thus runs into https://github.com/JuliaFolds2/BangBang.jl/issues/21 (even after https://github.com/JuliaFolds2/BangBang.jl/pull/22). The problem comes down to

eltype(x)::ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}
typeof(__varinfo__[@varname(x[1])])::ReverseDiff.TrackedReal{Float64, Float64, Nothing}

i.e. eltype(x) !== typeof(__varinfo__[@varname(x[1])]), and so the method instance at https://github.com/JuliaFolds2/BangBang.jl/blob/1e4455451378d150a0359e56d5e7ed75b74ddd6a/src/base.jl#L531-L539 is not hit when using ReverseDiff.jl.

In contrast, this is not an issue for, say, ForwardDiff.jl or non-differentiated code, since there we always hit the mutating version, i.e. the check above passes.
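To make the failure path concrete, here is a simplified sketch of the decision described above (the helper name and the exact check are assumptions, not BangBang.jl's actual code):

# Simplified sketch (assumption, not BangBang's actual implementation):
# `setindex!!` mutates in place only when the new value can be stored in the array.
can_store_inplace(x::AbstractArray, v) = typeof(v) <: eltype(x)

# With ReverseDiff, `x` is allocated with
#   eltype(x) == TrackedReal{Float64, Float64, TrackedArray{...}}
# while the value retrieved from `__varinfo__` is a TrackedReal{Float64, Float64, Nothing},
# so the check fails, the out-of-place (NoBang) path is taken, and that path calls
# `copy!` on an array that still contains undef entries.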

Things to do

torfjelde commented 3 months ago

This was the debug script I was using:

using DynamicPPL, ReverseDiff, LogDensityProblems, LogDensityProblemsAD, Distributions
using BangBang, LinearAlgebra  # `BangBang.possible` and `Diagonal` are used below

s_global = nothing
s_global_i = nothing

@model function demo_assume_index_observe(
    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
) where {TV}
    # `assume` with indexing and `observe`
    s = TV(undef, length(x))
    global s_global
    s_global = s
    for i in eachindex(s)
        if haskey(__varinfo__, @varname(s[i]))
            global s_global_i
            s_global_i = __varinfo__[@varname(s[i])]
            @info "s[$i] varinfo" s_global_i

            @info "s[$i] check" BangBang.possible(BangBang._setindex!, s, s_global_i, 1)
        end
        @info "s[$i]" isassigned(s, i) typeof(s) eltype(s)
        s[i] ~ InverseGamma(2, 3)
    end
    m = TV(undef, length(x))
    for i in eachindex(m)
        @info "m[$i]" isassigned(m, i) eltype(m)
        m[i] ~ Normal(0, sqrt(s[i]))
    end
    x ~ MvNormal(m, Diagonal(s))

    return (; s=s, m=m, x=x, logp=DynamicPPL.getlogp(__varinfo__))
end

# (✓) WORKS!
model = demo_assume_index_observe()
f = ADgradient(AutoReverseDiff(), DynamicPPL.LogDensityFunction(model))
LogDensityProblems.logdensity_and_gradient(f, f.ℓ.varinfo[:])

# (×) BREAKS!
# NOTE: This requires the set up from `test/mcmc/abstractmcmc.jl`.
model = demo_assume_index_observe()
adtype = AutoReverseDiff()  # same adtype as the working case above
sampler = initialize_nuts(model)
sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; adtype, unconstrained=true), model)
sample(model, sampler_ext, 2; n_adapts=0, discard_initial=0)

### Error message:
# [1] getindex
#   @ ./array.jl:861 [inlined]
# [2] copy!(dst::Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, src::Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}})
#   @ Base ./abstractarray.jl:874
# [3] _setindex
#   @ ~/.julia/packages/BangBang/2pcna/src/NoBang/base.jl:133 [inlined]
# [4] may
#   @ ~/.julia/packages/BangBang/2pcna/src/core.jl:11 [inlined]
# [5] setindex!!
#   @ ~/.julia/packages/BangBang/2pcna/src/base.jl:478 [inlined]
# [6] set
#   @ ~/.julia/packages/BangBang/2pcna/src/accessors.jl:35 [inlined]

Note the very strange behavior: it works if I set up only the gradient computation by hand, but when I try to sample using the same adtype, we hit the issue! Seems to be something w.r.t. how the different variables are created.