devmotion closed this issue 4 years ago.
Commenting out `include("core/ad.jl")` in `test/runtests.jl` fixes the issue, so it's definitely related to the AD tests.
OK, commenting out https://github.com/TuringLang/Turing.jl/blob/089b106f22a87f56312d525e99d395cbd339456f/test/core/ad.jl#L28 (and the subsequent lines that define and test `grad_Turing2`) fixes the issue. So there's some problem with the Zygote backend in Julia 1.4 + master.
@mohamed82008 Do you have any idea what's going on there?
The output related to Turing and DynamicPPL in the stack trace ends with:

```
jl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2322
macro expansion at /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:348 [inlined]
macro expansion at /home/david/.julia/dev/Turing/test/test_utils/models.jl:12 [inlined]
##inner_function#577#43 at /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:493 [inlined]
_pullback at /home/david/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
adjoint at /home/david/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:167 [inlined]
_pullback at /home/david/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
#_#5 at /home/david/.julia/packages/DynamicPPL/E1QSS/src/model.jl:24 [inlined]
adjoint at /home/david/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:167 [inlined]
```
I'll see if anything changed on DynamicPPL master.
OK, still the same error on the latest DynamicPPL master (but at least I found a bug in DynamicPPL :stuck_out_tongue:).
The model in this failing test expands to the following (Turing master and DynamicPPL 0.5.0):
```julia
julia> @macroexpand begin
           @model gdemo_d() = begin
               s ~ InverseGamma(2, 3)
               m ~ Normal(0, sqrt(s))
               1.5 ~ Normal(m, sqrt(s))
               2.0 ~ Normal(m, sqrt(s))
               return s, m
           end
       end
quote
    #= REPL[3]:2 =#
    begin
        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:484 =#
        function var"##gdemo_d#397"()
            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:485 =#
            function var"##inner_function#371"(var"##vi#368"::DynamicPPL.VarInfo, var"##sampler#369"::DynamicPPL.AbstractSampler, var"##ctx#367"::DynamicPPL.AbstractContext, var"##model#370")
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:491 =#
                begin
                end
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:492 =#
                DynamicPPL.resetlogp!(var"##vi#368")
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:493 =#
                begin
                    #= REPL[3]:2 =#
                    #= REPL[3]:3 =#
                    begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:343 =#
                        var"##temp_right#373" = InverseGamma(2, 3)
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:344 =#
                        DynamicPPL.assert_dist(var"##temp_right#373", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:345 =#
                        var"##preprocessed#378" = begin
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:69 =#
                            var"##sym#398" = Val(:s)
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:72 =#
                            if !(DynamicPPL.inparams(var"##sym#398", Val{()}())) || DynamicPPL.inparams(var"##sym#398", DynamicPPL.getmissing(var"##model#370"))
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:73 =#
                                (begin
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                    DynamicPPL.VarName{:s}("")
                                end, ())
                            else
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:75 =#
                                if DynamicPPL.inparams(var"##sym#398", Val{()}())
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:77 =#
                                    var"##lhs#399" = s
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:78 =#
                                    if var"##lhs#399" === missing
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:79 =#
                                        (begin
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                            DynamicPPL.VarName{:s}("")
                                        end, ())
                                    else
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:81 =#
                                        var"##lhs#399"
                                    end
                                else
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:84 =#
                                    throw("This point should not be reached. Please report this error.")
                                end
                            end
                        end
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:346 =#
                        if var"##preprocessed#378" isa Tuple
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:347 =#
                            (var"##vn#376", var"##inds#377") = var"##preprocessed#378"
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:348 =#
                            var"##out#374" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#373", var"##vn#376", var"##inds#377", var"##vi#368")
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:349 =#
                            s = var"##out#374"[1]
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:350 =#
                            DynamicPPL.acclogp!(var"##vi#368", var"##out#374"[2])
                        else
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:352 =#
                            DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#373", var"##preprocessed#378", var"##vi#368"))
                        end
                    end
                    #= REPL[3]:4 =#
                    begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:343 =#
                        var"##temp_right#379" = Normal(0, sqrt(s))
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:344 =#
                        DynamicPPL.assert_dist(var"##temp_right#379", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:345 =#
                        var"##preprocessed#384" = begin
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:69 =#
                            var"##sym#400" = Val(:m)
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:72 =#
                            if !(DynamicPPL.inparams(var"##sym#400", Val{()}())) || DynamicPPL.inparams(var"##sym#400", DynamicPPL.getmissing(var"##model#370"))
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:73 =#
                                (begin
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                    DynamicPPL.VarName{:m}("")
                                end, ())
                            else
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:75 =#
                                if DynamicPPL.inparams(var"##sym#400", Val{()}())
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:77 =#
                                    var"##lhs#401" = m
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:78 =#
                                    if var"##lhs#401" === missing
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:79 =#
                                        (begin
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                            DynamicPPL.VarName{:m}("")
                                        end, ())
                                    else
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:81 =#
                                        var"##lhs#401"
                                    end
                                else
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:84 =#
                                    throw("This point should not be reached. Please report this error.")
                                end
                            end
                        end
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:346 =#
                        if var"##preprocessed#384" isa Tuple
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:347 =#
                            (var"##vn#382", var"##inds#383") = var"##preprocessed#384"
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:348 =#
                            var"##out#380" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##vn#382", var"##inds#383", var"##vi#368")
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:349 =#
                            m = var"##out#380"[1]
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:350 =#
                            DynamicPPL.acclogp!(var"##vi#368", var"##out#380"[2])
                        else
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:352 =#
                            DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##preprocessed#384", var"##vi#368"))
                        end
                    end
                    #= REPL[3]:5 =#
                    begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:360 =#
                        var"##temp_right#385" = Normal(m, sqrt(s))
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:361 =#
                        DynamicPPL.assert_dist(var"##temp_right#385", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:362 =#
                        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#385", 1.5, var"##vi#368"))
                    end
                    #= REPL[3]:6 =#
                    begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:360 =#
                        var"##temp_right#391" = Normal(m, sqrt(s))
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:361 =#
                        DynamicPPL.assert_dist(var"##temp_right#391", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:362 =#
                        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#391", 2.0, var"##vi#368"))
                    end
                    #= REPL[3]:7 =#
                    return (s, m)
                end
            end
            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:495 =#
            return DynamicPPL.Model(var"##inner_function#371", NamedTuple(), begin
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:460 =#
                DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple())
            end)
        end
        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:497 =#
        gdemo_d = begin
            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:460 =#
            DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple())
        end
    end
end
```
It's completely unclear to me whether it's related, but these branches with `... = s` and `... = m` without defining `s` and `m` beforehand seem fishy. Maybe that's problematic for Zygote?
OK, I manually elided all unnecessary branches:

```julia
function var"##gdemo_d#397"()
    function var"##inner_function#371"(var"##vi#368"::DynamicPPL.VarInfo, var"##sampler#369"::DynamicPPL.AbstractSampler, var"##ctx#367"::DynamicPPL.AbstractContext, var"##model#370")
        DynamicPPL.resetlogp!(var"##vi#368")
        var"##temp_right#373" = InverseGamma(2, 3)
        DynamicPPL.assert_dist(var"##temp_right#373", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
        var"##preprocessed#378" = (DynamicPPL.VarName{:s}(""), ())
        (var"##vn#376", var"##inds#377") = var"##preprocessed#378"
        var"##out#374" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#373", var"##vn#376", var"##inds#377", var"##vi#368")
        s = var"##out#374"[1]
        DynamicPPL.acclogp!(var"##vi#368", var"##out#374"[2])
        var"##temp_right#379" = Normal(0, sqrt(s))
        DynamicPPL.assert_dist(var"##temp_right#379", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
        var"##preprocessed#384" = (DynamicPPL.VarName{:m}(""), ())
        (var"##vn#382", var"##inds#383") = var"##preprocessed#384"
        var"##out#380" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##vn#382", var"##inds#383", var"##vi#368")
        m = var"##out#380"[1]
        DynamicPPL.acclogp!(var"##vi#368", var"##out#380"[2])
        var"##temp_right#385" = Normal(m, sqrt(s))
        DynamicPPL.assert_dist(var"##temp_right#385", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#385", 1.5, var"##vi#368"))
        var"##temp_right#391" = Normal(m, sqrt(s))
        DynamicPPL.assert_dist(var"##temp_right#391", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#391", 2.0, var"##vi#368"))
        return (s, m)
    end
    return DynamicPPL.Model(var"##inner_function#371", NamedTuple(), DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple()))
end
gdemo_d = DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple())
```
and used this model for testing, but Julia still crashes. The stack trace points to the line

```julia
var"##out#380" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##vn#382", var"##inds#383", var"##vi#368")
```

where `m` is sampled.
I don't know how I missed this issue when looking through possibly related Zygote issues: https://github.com/FluxML/Zygote.jl/issues/486. So it seems other people have seen the same error with Zygote before.
The problem is caused by https://github.com/TuringLang/DynamicPPL.jl/blob/d3ff24210ab6d7a0d994839d36148ddfca9f8d2b/src/context_implementations.jl#L135. Without this line, the gradient can be computed as expected (in my custom model I called a custom version without this line since, of course, the line is needed when sampling from the model).
I'm just dreaming, but maybe we could avoid these issues with mutation and type instabilities by replacing `VarInfo` with a strictly typed structure that's basically an extended `NamedTuple`, such that the type tells us from which `VarName`s we have sampled and of which type their samples are? Instead of `push!`ing, we would just create a new `VarInfo` with one additional sample (like merging/extending a `NamedTuple`). Additionally, the samples would be neither vectorized nor transformed: we would just store the samples as they are and get rid of all the current vectorization and transformation problems. Transformation or vectorization would only be done when it is actually needed (e.g., when computing HMC steps), and since all information about the samples (such as dimension etc.) is available at all times, no hacks would be needed. When executing `assume` statements with `SampleFromPrior` or `SampleUniform`, we would just dispatch to different methods that either return the current sample or add a new sample. It seems that every execution of the model could then be type stable, and computing gradients should be fine with Zygote as well.
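A minimal sketch of the `NamedTuple` idea above, in plain Julia with no Turing or DynamicPPL dependency. `NTVarInfo`, `addsample` and `getsample` are hypothetical names chosen for illustration, not DynamicPPL API. Adding a sample builds a new immutable store via `merge`, so the store's type records every variable name and sample type, and no `push!` (mutation) is involved.

```julia
# Hypothetical, illustration only: not part of DynamicPPL.
# An immutable store whose type encodes which variables were sampled and the
# types of their samples (kept as drawn: not vectorized, not transformed).
struct NTVarInfo{NT<:NamedTuple}
    samples::NT
    logp::Float64
end

# Empty store: no samples yet, accumulated log density 0.
NTVarInfo() = NTVarInfo(NamedTuple(), 0.0)

# Return a *new* store with one more sample and the accumulated log density.
# The input store is left untouched: no push!, no mutation.
function addsample(vi::NTVarInfo, name::Symbol, value, logp::Real)
    samples = merge(vi.samples, NamedTuple{(name,)}((value,)))
    return NTVarInfo(samples, Float64(vi.logp + logp))
end

# Look up a sample by variable name.
getsample(vi::NTVarInfo, name::Symbol) = getfield(vi.samples, name)

# Usage: the values are stand-ins for draws of `s` and `m` in `gdemo_d`.
vi0 = NTVarInfo()
vi1 = addsample(vi0, :s, 2.0, -1.2)
vi2 = addsample(vi1, :m, [0.1, 0.3], -0.5)
```

Because `vi2` has type `NTVarInfo{NamedTuple{(:s, :m), Tuple{Float64, Vector{Float64}}}}`, code that reads from it can be specialized on the names and sample types, which is the property the proposal relies on for type stability.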
Nice investigation! So there seems to be an issue with Zygote and Julia 1.4 here, which may be caused by the `push!` branch, as you found. If that's the case, there is a simple workaround: wrap `push!` in a function and define `@nograd` on it. This branch will never be reached when computing gradients anyways.
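A minimal sketch of that workaround outside of DynamicPPL. `record!` and `f` are hypothetical stand-ins for the `push!` branch and the model. The thread refers to `Zygote.@nograd`; as far as I know, more recent Zygote versions recommend `ChainRulesCore.@non_differentiable` for the same purpose, so the sketch uses that macro instead.

```julia
using Zygote, ChainRulesCore

# Hypothetical helper wrapping the mutating push! (stand-in for the push!
# branch in DynamicPPL) so that it can be excluded from differentiation.
function record!(buffer::Vector{Float64}, x::Real)
    push!(buffer, x)
    return nothing
end

# Tell the AD system not to trace into record! (the role @nograd played in Zygote).
ChainRulesCore.@non_differentiable record!(::Any, ::Any)

# Toy function: it mutates a buffer, but its output depends on x only.
function f(x)
    buffer = Float64[]
    record!(buffer, x)
    return x^2
end

g = Zygote.gradient(f, 3.0)
```

Marking the wrapper as non-differentiable is only safe because, as noted above, the mutating branch does not contribute to the value being differentiated.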
Your idea on `VarInfo` is reasonable for some model classes and would be interesting to experiment with, but imagine a model where we loop over the unknown number of random variables in a non-contiguous order and where we need to pause the model and fork for particle samplers. For that, the current approach is probably the best we can do. Although after implementing the proposal in https://github.com/TuringLang/DynamicPPL.jl/issues/4, your suggested design for `VarInfo` can become feasible for any model and all non-particle samplers.
> Wrap push! in a function and define @nograd on it. This branch will never be reached when computing gradients anyways.
Yes that's what I thought as well.
> imagine a model where we loop over the unknown number of random variables in a non-contiguous order and where we need to pause the model and fork for particle samplers. For that, the current approach is probably the best we can do. Although after implementing the proposal in TuringLang/DynamicPPL.jl#4, your suggested design for VarInfo can become feasible for any model and all non-particle samplers.
What exactly do you have in mind? I don't see why an unknown number of random variables would be a problem. If, e.g., there is a stochastic number of random variables, you could just create an uninitialized vector for these random variables. Or you could initialize an empty vector and push to that vector in the model definition; in this case the mutation is caused by the user, and hence she can't expect Zygote to work.
Consider this:
```julia
x = Vector{T}(undef, 10)
for i in 10:-1:1
    x[i] ~ Normal()
end
```
`VarInfo` only sees the `~` lines. It doesn't know the size of `x` ahead of time. When it sees `x[10]` first, it doesn't know whether the next element is `x[11]`, `x[9]` or some other index. So preallocating `x` in `VarInfo` doesn't work in this case unless we wait until the end to put the sample in `VarInfo`. But this will require the nested model approach to get "the sample" as a `NamedTuple` and then change it to `VarInfo` at the end.
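A minimal sketch of "waiting until the end", in plain Julia and not based on `VarInfo` internals. `collect_indexed` is a hypothetical helper, and the `draw` function stands in for sampling `x[i] ~ Normal()`. Samples are stored by index in whatever order they are visited, and the vector is only allocated once the full set of indices, and hence its length, is known.

```julia
# Hypothetical sketch: collect x[i] in the order the model visits them and
# only build the vector at the end, once all indices are known.
function collect_indexed(order, draw)
    seen = Dict{Int,Float64}()            # index => sampled value; size unknown up front
    for i in order
        seen[i] = draw(i)                 # stand-in for `x[i] ~ Normal()`
    end
    n = maximum(keys(seen))
    # Require the indices to be exactly 1:n, otherwise entries would stay undefined.
    (length(seen) == n && all(i -> haskey(seen, i), 1:n)) ||
        error("indices are not exactly 1:$n")
    x = Vector{Float64}(undef, n)
    for (i, v) in seen
        x[i] = v
    end
    return x
end

# Visit the indices in reverse, as in the loop above.
x = collect_indexed(10:-1:1, i -> float(i))
```

The catch mentioned above shows up here too: until the loop has finished, nothing knows that `x` has length 10.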
In this case, one would just naturally get variables `x[10]`, `x[9]`, and so on with their sampled values in the named tuple approach. However, if one wants to have only one vector-valued variable `x`, one could add the syntax `@variable x` to indicate that `x` is a variable. One would have to state that below the initialization of `x`, i.e.

```julia
x = Vector{T}(undef, 10)
@variable x
```

Then the shape and type of `x` are also known.
I just noticed that the CI tests fail on Julia 1.4 + master (see e.g. https://travis-ci.com/github/TuringLang/Turing.jl/builds/158167819 which I reran on the current master branch of Turing) due to
It seems to be AD related?
Googling revealed the following issue that was caused by the same failing assertion: https://github.com/JuliaLang/julia/issues/34247