torfjelde opened this issue 3 months ago
This was the debug script I was using:
```julia
using DynamicPPL, ReverseDiff, LogDensityProblems, LogDensityProblemsAD, Distributions
using BangBang, LinearAlgebra  # for `BangBang.possible` and `Diagonal` below
using ADTypes                  # provides `AutoReverseDiff` / `AutoForwardDiff`

s_global = nothing
s_global_i = nothing

@model function demo_assume_index_observe(
    x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
) where {TV}
    # `assume` with indexing and `observe`
    s = TV(undef, length(x))
    global s_global
    s_global = s
    for i in eachindex(s)
        if haskey(__varinfo__, @varname(s[i]))
            global s_global_i
            s_global_i = __varinfo__[@varname(s[i])]
            @info "s[$i] varinfo" s_global_i
            @info "s[$i] check" BangBang.possible(BangBang._setindex!, s, s_global_i, 1)
        end
        @info "s[$i]" isassigned(s, i) typeof(s) eltype(s)
        s[i] ~ InverseGamma(2, 3)
    end
    m = TV(undef, length(x))
    for i in eachindex(m)
        @info "m[$i]" isassigned(m, i) eltype(m)
        m[i] ~ Normal(0, sqrt(s[i]))
    end
    x ~ MvNormal(m, Diagonal(s))
    return (; s=s, m=m, x=x, logp=DynamicPPL.getlogp(__varinfo__))
end

# (✓) WORKS!
model = demo_assume_index_observe()
f = ADgradient(AutoReverseDiff(), DynamicPPL.LogDensityFunction(model))
LogDensityProblems.logdensity_and_gradient(f, f.ℓ.varinfo[:])

# (×) BREAKS!
# NOTE: This requires the setup from `test/mcmc/abstractmcmc.jl`.
model = demo_assume_index_observe()
adtype = AutoReverseDiff()
sampler = initialize_nuts(model)
sampler_ext = DynamicPPL.Sampler(externalsampler(sampler; adtype, unconstrained=true), model)
sample(model, sampler_ext, 2; n_adapts=0, discard_initial=0)
```
### Error message:

```
 [1] getindex
   @ ./array.jl:861 [inlined]
 [2] copy!(dst::Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, src::Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}})
   @ Base ./abstractarray.jl:874
 [3] _setindex
   @ ~/.julia/packages/BangBang/2pcna/src/NoBang/base.jl:133 [inlined]
 [4] may
   @ ~/.julia/packages/BangBang/2pcna/src/core.jl:11 [inlined]
 [5] setindex!!
   @ ~/.julia/packages/BangBang/2pcna/src/base.jl:478 [inlined]
 [6] set
   @ ~/.julia/packages/BangBang/2pcna/src/accessors.jl:35 [inlined]
```
Note the very strange behavior: setting up only the gradient computation by hand works, but sampling with the same adtype hits the issue! It seems to be something w.r.t. how the different variables are created.
### Issue

Julia v1.7's implementation of `copy!` is not compatible with `undef` entries, which causes quite a few issues for us when we're using BangBang.jl. Ref: https://github.com/JuliaFolds2/BangBang.jl/issues/21 and https://github.com/JuliaFolds2/BangBang.jl/pull/22 and their references. After https://github.com/JuliaFolds2/BangBang.jl/pull/22, a lot of these issues were addressed, but it seems it still doesn't quite do it when we're working with AD backends which use types for the tracing, e.g. ReverseDiff.jl.
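A minimal sketch of the underlying `copy!` problem, detached from DynamicPPL (the eltypes below are illustrative stand-ins, not taken from the model above): when destination and source eltypes differ, the copy falls back to an element-by-element loop, and reading an `#undef` entry of a non-isbits array throws.

```julia
# A non-isbits source array starts out with #undef entries.
src = Vector{Real}(undef, 2)
src[1] = 1.0                      # src[2] stays #undef

# A destination with a *different* eltype forces an element-wise
# `copyto!` loop; reading the #undef entry `src[2]` throws.
dst = Vector{Float64}(undef, 2)
err = try
    copy!(dst, src)
    nothing
catch e
    e
end
@assert err isa UndefRefError
```

This is exactly the shape of frames [1]–[2] in the stack trace above: `getindex` on an `#undef` entry inside `copy!` between two `Vector`s with mismatched (tracked) eltypes.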
In particular, a model like the one above will fail when used with ReverseDiff.jl, due to hitting the `NoBang` version and thus https://github.com/JuliaFolds2/BangBang.jl/issues/21 (even after https://github.com/JuliaFolds2/BangBang.jl/pull/22). The problem comes down to `eltype(x) !== typeof(__varinfo__[@varname(x[1])])`, and so the method instances https://github.com/JuliaFolds2/BangBang.jl/blob/1e4455451378d150a0359e56d5e7ed75b74ddd6a/src/base.jl#L531-L539 are not hit when using ReverseDiff.jl. In contrast, this is not an issue for, say, ForwardDiff.jl or non-diff code, since there we always hit the mutating version, i.e. the check above passes.
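The dispatch split can be seen with BangBang alone (a sketch assuming BangBang.jl is installed; `BigFloat` is an illustrative stand-in for a value type that doesn't match `eltype(x)`, as with ReverseDiff's `TrackedReal`):

```julia
using BangBang

x = Vector{Float64}(undef, 2)

# Value type fits the eltype: mutation is reported possible, so
# `setindex!!` is plain `setindex!` and #undef entries are never read.
@assert BangBang.possible(BangBang._setindex!, x, 1.0, 1)
y = setindex!!(x, 1.0, 1)
@assert y === x                   # same array, mutated in place

# Value type would widen the eltype: mutation is reported as not
# possible, so BangBang falls back to `NoBang.setindex`, which
# allocates a widened array and `copy!`s the old one over. Here the
# eltype is isbits so the copy is harmless; with a non-isbits eltype,
# copying the #undef entries is what throws in the trace above.
@assert !BangBang.possible(BangBang._setindex!, x, big"1.0", 2)
z = setindex!!(x, big"1.0", 2)
@assert z !== x && eltype(z) == BigFloat
```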
### Things to do

- `setindex!`, i.e. the mutating version, is indeed valid in this case, so BangBang.jl is incorrectly reporting it as not possible. I'll raise a corresponding issue over there.
- It's also somewhat unclear to me whether we'll ever be able to fully cover all the scenarios correctly, or if we'll have to play this "catch up" game indefinitely. It may therefore be worth considering implementing a slightly more stringent version of `BangBang.AccessorsImpl.prefermutation` which always uses `setindex!` in favour of `BangBang.NoBang.setindex` whenever we see an array.
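A rough sketch of what that more stringent behavior could look like (the name `force_setindex!!` is hypothetical; this only illustrates the "always mutate for arrays" idea, not a tested integration with `prefermutation`):

```julia
using BangBang

# Hypothetical: for plain `Array`s, always use the mutating `setindex!`
# (relying on `convert` to the eltype) instead of letting BangBang's
# `possible` check route us to the `NoBang.setindex` fallback.
function force_setindex!!(xs::Array, v, i...)
    setindex!(xs, v, i...)   # throws if `v` can't be converted to eltype(xs)
    return xs
end
# Everything else keeps BangBang's usual semantics.
force_setindex!!(xs, v, i...) = BangBang.setindex!!(xs, v, i...)
```

The trade-off: in genuinely widening cases this turns the `UndefRefError` into a `convert` error, which matches the observation above that `setindex!` is actually valid in our scenario.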