TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

CI tests fail on Julia 1.4 + master #1199

Closed: devmotion closed this issue 4 years ago

devmotion commented 4 years ago

I just noticed that the CI tests fail on Julia 1.4 + master (see e.g. https://travis-ci.com/github/TuringLang/Turing.jl/builds/158167819 which I reran on the current master branch of Turing) due to

julia: /buildworker/worker/package_linux64/build/src/codegen.cpp:4357: jl_cgval_t emit_expr(jl_codectx_t&, jl_value_t*, ssize_t): Assertion `token.V->getType()->isTokenTy()' failed.

signal (6): Aborted

in expression starting at /home/travis/build/TuringLang/Turing.jl/test/core/ad.jl:16

It seems to be AD related?

Googling revealed the following issue that was caused by the same failing assertion: https://github.com/JuliaLang/julia/issues/34247

devmotion commented 4 years ago

Commenting out include("core/ad.jl") in test/runtests.jl fixes the issue, so it's definitely related to the AD tests.

devmotion commented 4 years ago

OK, commenting out https://github.com/TuringLang/Turing.jl/blob/089b106f22a87f56312d525e99d395cbd339456f/test/core/ad.jl#L28 (and the subsequent lines that define and test grad_Turing2) fixes the issue. So there's some problem with the Zygote backend in Julia 1.4 + master.

devmotion commented 4 years ago

@mohamed82008 Do you have any idea what's going on there?

devmotion commented 4 years ago

The output related to Turing and DynamicPPL in the stack trace ends with

jl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2322
macro expansion at /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:348 [inlined]
macro expansion at /home/david/.julia/dev/Turing/test/test_utils/models.jl:12 [inlined]
##inner_function#577#43 at /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:493 [inlined]
_pullback at /home/david/.julia/packages/Zygote/KNUTW/src/compiler/interface2.jl:0
adjoint at /home/david/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:167 [inlined] 
_pullback at /home/david/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
#_#5 at /home/david/.julia/packages/DynamicPPL/E1QSS/src/model.jl:24 [inlined]
adjoint at /home/david/.julia/packages/Zygote/KNUTW/src/lib/lib.jl:167 [inlined]

I'll see if anything changed on DynamicPPL master.

devmotion commented 4 years ago

OK, still the same error on the latest DynamicPPL master (but at least I found a bug in DynamicPPL 😛).

devmotion commented 4 years ago

The model in this failing test expands to (Turing master and DynamicPPL 0.5.0)

julia> @macroexpand begin
       @model gdemo_d() = begin
         s ~ InverseGamma(2, 3)
         m ~ Normal(0, sqrt(s))
         1.5 ~ Normal(m, sqrt(s))
         2.0 ~ Normal(m, sqrt(s))
         return s, m
       end
       end
quote                               
    #= REPL[3]:2 =#                                                                                                                                                                  
    begin                                                                            
        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:484 =#                                                                                                       
        function var"##gdemo_d#397"()                     
            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:485 =#                                                                                                   
            function var"##inner_function#371"(var"##vi#368"::DynamicPPL.VarInfo, var"##sampler#369"::DynamicPPL.AbstractSampler, var"##ctx#367"::DynamicPPL.AbstractContext, var"##model#370")
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:491 =#                                                                                               
                begin                             
                end                                                                                                                                                                  
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:492 =#
                DynamicPPL.resetlogp!(var"##vi#368")    
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:493 =#
                begin                                                                                                                                                                
                    #= REPL[3]:2 =#                       
                    #= REPL[3]:3 =#        
                    begin               
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:343 =#               
                        var"##temp_right#373" = InverseGamma(2, 3)                                                                                                                   
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:344 =#
                        DynamicPPL.assert_dist(var"##temp_right#373", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:345 =#
                        var"##preprocessed#378" = begin                                                                                                                              
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:69 =#
                                var"##sym#398" = Val(:s)                                                                                                                             
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:72 =#
                                if !(DynamicPPL.inparams(var"##sym#398", Val{()}())) || DynamicPPL.inparams(var"##sym#398", DynamicPPL.getmissing(var"##model#370"))
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:73 =#                                                              
                                    (begin                                                                                                                                           
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                            DynamicPPL.VarName{:s}("")                                                                                                               
                                        end, ())                                 
                                else
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:75 =#
                                    if DynamicPPL.inparams(var"##sym#398", Val{()}())                                                                                                
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:77 =#
                                        var"##lhs#399" = s
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:78 =#
                                        if var"##lhs#399" === missing
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:79 =#
                                            (begin                
                                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                                    DynamicPPL.VarName{:s}("")                                                                                                       
                                                end, ())                                                                                                                             
                                        else                                                                                                                                         
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:81 =#
                                            var"##lhs#399"
                                        end
                                    else                                                                                                                                             
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:84 =#
                                        throw("This point should not be reached. Please report this error.")
                                    end                                                                                                                                              
                                end                                                                                                                                                  
                            end                                                                                                                                                      
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:346 =#
                        if var"##preprocessed#378" isa Tuple
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:347 =#
                            (var"##vn#376", var"##inds#377") = var"##preprocessed#378"
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:348 =#
                            var"##out#374" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#373", var"##vn#376", var"##inds#377", var"##vi#368")
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:349 =#
                            s = var"##out#374"[1]                                                                                                                                    
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:350 =#
                            DynamicPPL.acclogp!(var"##vi#368", var"##out#374"[2])
                        else
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:352 =#
                            DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#373", var"##preprocessed#378", var"##vi#368"))
                        end                                                           
                    end                                                  
                    #= REPL[3]:4 =#
                    begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:343 =#
                        var"##temp_right#379" = Normal(0, sqrt(s))

                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:344 =#
                        DynamicPPL.assert_dist(var"##temp_right#379", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:345 =#
                        var"##preprocessed#384" = begin
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:69 =#
                                var"##sym#400" = Val(:m)
                                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:72 =#
                                if !(DynamicPPL.inparams(var"##sym#400", Val{()}())) || DynamicPPL.inparams(var"##sym#400", DynamicPPL.getmissing(var"##model#370"))
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:73 =#
                                    (begin
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                            DynamicPPL.VarName{:m}("")
                                        end, ())
                                else
                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:75 =#
                                    if DynamicPPL.inparams(var"##sym#400", Val{()}())
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:77 =#
                                        var"##lhs#401" = m
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:78 =#
                                        if var"##lhs#401" === missing
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:79 =#
                                            (begin
                                                    #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/varname.jl:38 =#
                                                    DynamicPPL.VarName{:m}("")
                                                end, ())
                                        else
                                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:81 =#
                                            var"##lhs#401"
                                        end
                                    else
                                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:84 =#
                                        throw("This point should not be reached. Please report this error.")
                                    end
                                end
                            end
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:346 =#
                        if var"##preprocessed#384" isa Tuple
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:347 =#
                            (var"##vn#382", var"##inds#383") = var"##preprocessed#384"
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:348 =#
                            var"##out#380" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##vn#382", var"##inds#383", var"##vi#368")
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:349 =#
                            m = var"##out#380"[1]
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:350 =#
                            DynamicPPL.acclogp!(var"##vi#368", var"##out#380"[2])
                        else
                            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:352 =#
                            DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##preprocessed#384", var"##vi#368"))
                        end
                    end
                    #= REPL[3]:5 =#
                    begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:360 =#
                        var"##temp_right#385" = Normal(m, sqrt(s))
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:361 =#
                        DynamicPPL.assert_dist(var"##temp_right#385", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:362 =#
                        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#385", 1.5, var"##vi#368"))
                    end
                    #= REPL[3]:6 =#
                    begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:360 =#
                        var"##temp_right#391" = Normal(m, sqrt(s))
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:361 =#
                        DynamicPPL.assert_dist(var"##temp_right#391", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:362 =#
                        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#391", 2.0, var"##vi#368"))
                    end
                    #= REPL[3]:7 =#
                    return (s, m)
                end
            end
            #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:495 =#
            return DynamicPPL.Model(var"##inner_function#371", NamedTuple(), begin
                        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:460 =#
                        DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple())
                    end)
        end
        #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:497 =#
        gdemo_d = begin
                #= /home/david/.julia/packages/DynamicPPL/E1QSS/src/compiler.jl:460 =#
                DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple())
            end
    end
end

It's completely unclear to me whether this is related, but the branches that assign ... = s and ... = m before s and m are defined seem fishy. Maybe that's problematic for Zygote?

devmotion commented 4 years ago

OK, I manually elided all unnecessary branches

function var"##gdemo_d#397"()
    function var"##inner_function#371"(var"##vi#368"::DynamicPPL.VarInfo, var"##sampler#369"::DynamicPPL.AbstractSampler, var"##ctx#367"::DynamicPPL.AbstractContext, var"##model#370")
        DynamicPPL.resetlogp!(var"##vi#368")

        var"##temp_right#373" = InverseGamma(2, 3)
        DynamicPPL.assert_dist(var"##temp_right#373", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")

        var"##preprocessed#378" = (DynamicPPL.VarName{:s}(""), ())
        (var"##vn#376", var"##inds#377") = var"##preprocessed#378"
        var"##out#374" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#373", var"##vn#376", var"##inds#377", var"##vi#368")
        s = var"##out#374"[1]
        DynamicPPL.acclogp!(var"##vi#368", var"##out#374"[2])

        var"##temp_right#379" = Normal(0, sqrt(s))
        DynamicPPL.assert_dist(var"##temp_right#379", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")

        var"##preprocessed#384" = (DynamicPPL.VarName{:m}(""), ())
        (var"##vn#382", var"##inds#383") = var"##preprocessed#384"
        var"##out#380" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##vn#382", var"##inds#383", var"##vi#368")
        m = var"##out#380"[1]
        DynamicPPL.acclogp!(var"##vi#368", var"##out#380"[2])

        var"##temp_right#385" = Normal(m, sqrt(s))
        DynamicPPL.assert_dist(var"##temp_right#385", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")

        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#385", 1.5, var"##vi#368"))

        var"##temp_right#391" = Normal(m, sqrt(s))
        DynamicPPL.assert_dist(var"##temp_right#391", msg = "Right-hand side of a ~ must be subtype of Distribution or a vector of Distributions on line 340.")

        DynamicPPL.acclogp!(var"##vi#368", DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#391", 2.0, var"##vi#368"))

        return (s, m)
    end

    return DynamicPPL.Model(var"##inner_function#371", NamedTuple(), DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple()))
end

gdemo_d = DynamicPPL.ModelGen{()}(var"##gdemo_d#397", NamedTuple())

and used this model for testing, but Julia still crashes. The stack trace points to the line

var"##out#380" = DynamicPPL.tilde(var"##ctx#367", var"##sampler#369", var"##temp_right#379", var"##vn#382", var"##inds#383", var"##vi#368")

where m is sampled.

devmotion commented 4 years ago

I don't know how I missed this issue when looking through possibly related Zygote issues: https://github.com/FluxML/Zygote.jl/issues/486. It seems other people have seen the same error with Zygote before.

devmotion commented 4 years ago

The problem is caused by https://github.com/TuringLang/DynamicPPL.jl/blob/d3ff24210ab6d7a0d994839d36148ddfca9f8d2b/src/context_implementations.jl#L135. Without this line the gradient can be computed as expected (I called a custom version in my custom model since of course this line is needed when sampling from the model).

devmotion commented 4 years ago

I'm just dreaming, but maybe we could avoid these issues with mutation and type instabilities by replacing VarInfo with a strictly typed structure that is basically an extended NamedTuple, so that the type tells us which VarNames we have sampled and what types their samples have. Instead of push!ing, we would create a new VarInfo with one additional sample (like merging or extending a NamedTuple). Additionally, the samples would be neither vectorized nor transformed: we would store them as they are and get rid of all the vectorization and transformation problems we currently have. Transformation or vectorization would be done only when actually needed (e.g., when computing HMC steps), since all information about the samples (such as their dimension) would then be available at all times without any hacks. When executing assume statements with SampleFromPrior or SampleUniform, we would just dispatch to different methods that either return the current sample or add a new sample. It seems that every execution of the model could then be type stable, and computing gradients with Zygote should be fine as well.
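
To make the idea above a bit more concrete, here is a minimal sketch in plain Julia. The names TypedTrace and addsample are hypothetical and are not part of DynamicPPL; the sketch only shows how adding a sample can return a new, more specifically typed object instead of mutating an existing one.

```julia
# Hypothetical immutable trace: the NamedTuple type parameter records
# which variables have been sampled and the type of each sample.
struct TypedTrace{NT<:NamedTuple}
    samples::NT
end

# An empty trace holds an empty NamedTuple.
TypedTrace() = TypedTrace(NamedTuple())

# Instead of push!ing, return a new trace with one additional sample.
function addsample(t::TypedTrace, ::Val{name}, value) where {name}
    return TypedTrace(merge(t.samples, NamedTuple{(name,)}((value,))))
end

t0 = TypedTrace()
t1 = addsample(t0, Val(:s), 1.2)          # scalar sample, stored as it is
t2 = addsample(t1, Val(:m), [0.1, 0.2])   # vector sample, not flattened
```

Each call leaves the previous trace untouched, and typeof(t2) lists both variable names together with their sample types.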

mohamed82008 commented 4 years ago

Nice investigation! So there seems to be an issue with Zygote and Julia 1.4 here which may be caused by the push! branch as you investigated. If that's the case, there is a simple workaround here. Wrap push! in a function and define @nograd on it. This branch will never be reached when computing gradients anyways.
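
A rough sketch of that workaround, assuming a Zygote version that provides the Zygote.@nograd macro. The names push_sample! and logp are hypothetical stand-ins and this is not the actual DynamicPPL code.

```julia
using Zygote

# Hypothetical wrapper around the mutating call, so that Zygote can be
# told to skip it as a whole.
push_sample!(store, x) = (push!(store, x); nothing)

# Mark the wrapper as having no gradient; Zygote then calls it directly
# instead of tracing into push!.
Zygote.@nograd push_sample!

# Stand-in for a log density that also does some bookkeeping.
function logp(x, store)
    push_sample!(store, 0.0)   # mutation that is irrelevant for the gradient
    return -x^2 / 2
end

store = Float64[]
g = Zygote.gradient(x -> logp(x, store), 1.5)
```

The bookkeeping still runs in the forward pass, but it contributes nothing to the gradient with respect to x.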

mohamed82008 commented 4 years ago

Your idea for VarInfo is reasonable for some model classes and would be interesting to experiment with, but imagine a model where we loop over an unknown number of random variables in a non-contiguous order and where we need to pause the model and fork for particle samplers. For that, the current approach is probably the best we can do. However, after implementing the proposal in https://github.com/TuringLang/DynamicPPL.jl/issues/4, your suggested design for VarInfo could become feasible for any model and all non-particle samplers.

devmotion commented 4 years ago

> Wrap push! in a function and define @nograd on it. This branch will never be reached when computing gradients anyways.

Yes that's what I thought as well.

devmotion commented 4 years ago

> imagine a model where we loop over an unknown number of random variables in a non-contiguous order and where we need to pause the model and fork for particle samplers. For that, the current approach is probably the best we can do. Although after implementing the proposal in TuringLang/DynamicPPL.jl#4, your suggested design for VarInfo can become feasible for any model and all non-particle samplers.

What exactly do you have in mind? I don't see why an unknown number of random variables would be a problem. If, e.g., there is a stochastic number of random variables, you could just initialize an uninitialized vector for these random variables. Or you could initialize an empty vector and push to that vector in the model definition - in this case the mutation is caused by the user and hence she can't expect Zygote to work.

mohamed82008 commented 4 years ago

Consider this:

x = Vector{T}(undef, 10)
for i in 10:-1:1
    x[i] ~ Normal()
end

VarInfo only sees the ~ lines. It doesn't know the size of x ahead of time. When it sees x[10] first, it doesn't know whether the next element is x[11], x[9], or some other index. So preallocating x in VarInfo doesn't work in this case unless we wait until the end to put the sample in VarInfo. But this would require the nested model approach to get "the sample" as a NamedTuple and then convert it to VarInfo at the end.
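
The ordering problem can be illustrated without Turing at all. In the sketch below, visit_order is a hypothetical helper and randn() merely stands in for x[i] ~ Normal(); the recorder learns about an index only when the corresponding statement runs.

```julia
# Hypothetical recorder: it only learns about an element of x at the
# moment the statement for that element is executed.
function visit_order(n)
    seen = Int[]                     # indices in the order they are encountered
    x = Vector{Float64}(undef, n)
    for i in n:-1:1
        x[i] = randn()               # stands in for `x[i] ~ Normal()`
        push!(seen, i)
    end
    return seen, x
end

seen, x = visit_order(10)
```

The first index recorded is 10 and the last is 1, so a store that fills a contiguous buffer in arrival order cannot know the final layout of x until the loop has finished.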

devmotion commented 4 years ago

In this case one would just naturally get variables x[10], x[9], and so on with their sampled values in the named tuple approach. However, if one wants to have only one vector-valued variable x, one could add the syntax

@variable x

to indicate that x is a variable. One would have to state that below the initialization of x, i.e.

x = Vector{T}(undef, 10)
@variable x

Then the shape and type of x are also known.