TuringLang / DynamicPPL.jl

Implementation of domain-specific language (DSL) for dynamic probabilistic programming

Sampling from model prior with `missing` is thread-unsafe #641

Open penelopeysm opened 3 weeks ago

penelopeysm commented 3 weeks ago

Reported via Julia Slack: https://julialang.slack.com/archives/CCYDC34A0/p1723554415531969

MWE with current master branch of Turing:

using Turing

@model function demo_threading(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    Threads.@threads for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
    return (s, m)
end

demo_threading(fill(missing, 1000))()

The last line usually crashes, and the error message varies from run to run. Sometimes it is `destination has fewer elements than required`, sometimes it is `KeyError: key x[1] not found`, and sometimes it actually runs correctly!

Everything behaves as expected if you run Julia with a single thread (i.e. without the -t4 argument below) or if you remove Threads.@threads from the model above.

$ julia --project=. -t4

julia> using Turing

julia> @model function demo_threading(x)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           Threads.@threads for i in eachindex(x)
               x[i] ~ Normal(m, sqrt(s))
           end
           return (s, m)
       end
demo_threading (generic function with 2 methods)

julia> demo_threading(fill(missing, 1000))()
ERROR: TaskFailedException

    nested task error: ArgumentError: destination has fewer elements than required
      [1] copyto!(dest::Vector{AbstractPPL.VarName}, src::Base.KeySet{AbstractPPL.VarName, Dict{AbstractPPL.VarName, Int64}})
        @ Base ./abstractarray.jl:944
      [2] _collect
        @ ./array.jl:765 [inlined]
      [3] collect
        @ ./array.jl:759 [inlined]
      [4] istrans
        @ ~/.julia/packages/DynamicPPL/DvdZw/src/abstract_varinfo.jl:529 [inlined]
      [5] assume(rng::Random.TaskLocalRNG, sampler::DynamicPPL.SampleFromPrior, dist::Normal{…}, vn::AbstractPPL.VarName{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
        @ DynamicPPL ~/.julia/packages/DynamicPPL/DvdZw/src/context_implementations.jl:254
      [6] tilde_assume
        @ ~/.julia/packages/DynamicPPL/DvdZw/src/context_implementations.jl:72 [inlined]
      [7] tilde_assume
        @ ~/.julia/packages/DynamicPPL/DvdZw/src/context_implementations.jl:67 [inlined]
      [8] tilde_assume
        @ ~/.julia/packages/DynamicPPL/DvdZw/src/context_implementations.jl:52 [inlined]
      [9] tilde_assume!!(context::DynamicPPL.SamplingContext{…}, right::Normal{…}, vn::AbstractPPL.VarName{…}, vi::DynamicPPL.ThreadSafeVarInfo{…})
        @ DynamicPPL ~/.julia/packages/DynamicPPL/DvdZw/src/context_implementations.jl:144
     [10] (::var"#3#threadsfor_fun#2"{var"#3#threadsfor_fun#1#3"{…}})(tid::Int64; onethread::Bool)
        @ Main ./threadingconstructs.jl:215
     [11] #3#threadsfor_fun
        @ ./threadingconstructs.jl:182 [inlined]
     [12] (::Base.Threads.var"#1#2"{var"#3#threadsfor_fun#2"{var"#3#threadsfor_fun#1#3"{…}}, Int64})()
        @ Base.Threads ./threadingconstructs.jl:154

...and 3 more exceptions.

  [1] threading_run(fun::var"#3#threadsfor_fun#2"{var"#3#threadsfor_fun#1#3"{…}}, static::Bool)
    @ Base.Threads ./threadingconstructs.jl:172
  [2] macro expansion
    @ ./threadingconstructs.jl:220 [inlined]
  [3] demo_threading(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, x::Vector{…})
    @ Main ./REPL[2]:4
  [4] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/model.jl:973 [inlined]
  [5] evaluate_threadsafe!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.UntypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/DvdZw/src/model.jl:962
  [6] evaluate!!
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/model.jl:892 [inlined]
  [7] evaluate!! (repeats 2 times)
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/model.jl:905 [inlined]
  [8] evaluate!!
    @ ~/.julia/packages/DynamicPPL/DvdZw/src/model.jl:915 [inlined]
  [9] (::DynamicPPL.Model{typeof(demo_threading), (:x,), (), (), Tuple{Vector{…}}, Tuple{}, DynamicPPL.DefaultContext})()
    @ DynamicPPL ~/.julia/packages/DynamicPPL/DvdZw/src/model.jl:865
 [10] top-level scope
    @ REPL[3]:1
Some type information was truncated. Use `show(err)` to see complete types.
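
The innermost frames of the trace show `collect` over the key set of a `Dict{VarName, Int64}` failing inside `istrans`, which would be consistent with one task reading the dictionary while another task inserts into it. Below is a minimal sketch of that general pattern in plain Julia, independent of DynamicPPL. The function names `unsafe_fill` and `locked_fill` are made up for illustration, and the lock is only one possible way to serialize access, not a proposal for how DynamicPPL should handle this.

```julia
using Base.Threads: @threads

# Hypothetical illustration: several tasks insert into one shared Dict.
# Dict is not thread-safe, so with more than one thread this may throw,
# lose entries, or happen to work. It is defined here but never called.
function unsafe_fill(n)
    d = Dict{Int,Int}()
    @threads for i in 1:n
        d[i] = i  # unsynchronized concurrent mutation
    end
    return d
end

# Same loop, but every access to the Dict goes through a lock,
# so only one task mutates it at a time.
function locked_fill(n)
    d = Dict{Int,Int}()
    lk = ReentrantLock()
    @threads for i in 1:n
        lock(lk) do
            d[i] = i
        end
    end
    return d
end

d = locked_fill(1000)
@assert length(d) == 1000
```

Serializing every insertion like this removes most of the benefit of threading for cheap loop bodies, so it is shown only to make the failure mode concrete.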

penelopeysm commented 3 weeks ago

Oh, I just noticed that @torfjelde wrote:

Note that this only works for observe-statements, i.e. when the LHS of ~ is “fixed” / not random.

Is the use case above something we want to support, or is this a wontfix?

yebai commented 3 weeks ago

We might be able to support this particular case. Can you investigate why it causes crashes so we can discuss this in more detail?