TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License
2k stars 216 forks source link

`missing` keyword arguments not properly processed #2259

Open daeh opened 3 weeks ago

daeh commented 3 weeks ago

The documentation gives an example of how a `@model` argument can be used either to condition the model on data or, when it is `missing`, to sample the corresponding random variables (RVs):

# Example from the Turing documentation: a Gaussian model with unknown
# variance `s²` and mean `m`. `x` is a positional model argument: passing data
# conditions the model on it, while passing `missing` turns every `x[i]` into
# a random variable that is sampled.
# `T` is the element type used when `x` has to be allocated below.
# NOTE(review): presumably this lets the sampler substitute an AD-compatible
# number type (e.g. ForwardDiff duals) — confirm against the DynamicPPL docs.
@model function gdemo(x, ::Type{T}=Float64) where {T}
    if x === missing
        # Initialize `x` if missing: an uninitialized length-2 vector whose
        # entries are then drawn by the `x[i] ~ ...` statements below.
        x = Vector{T}(undef, 2)
    end
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
end

# Construct a model with x = missing, so that x[1] and x[2] are sampled too
model = gdemo(missing)
# Draw 500 samples with HMC (step size 0.01, 5 leapfrog steps)
c = sample(model, HMC(0.01, 5), 500)

If `x` is turned into a keyword argument instead, the same example produces an error:

using Turing

# Same model as `gdemo`, but with `x` as a keyword argument (default
# `missing`). This is the reproducer for the issue: sampling it with
# `x=missing` fails with the DomainError shown after this snippet.
# NOTE(review): per the discussion, keyword arguments were not treated like
# positional ones, so the `x[i]` below were presumably not recognized as
# random variables even though `x` is `missing` — check with
# `rand(Turing.OrderedDict, model_kw)`.
@model function gdemo_kw(::Type{T}=Float64; x=missing) where {T}
    if x === missing
        # Initialize `x` if missing
        x = Vector{T}(undef, 2)
    end
    s² ~ InverseGamma(2, 3)
    # The reported DomainError is raised on the next line: `sqrt(s²)` is NaN
    # there during the HMC gradient evaluation.
    m ~ Normal(0, sqrt(s²))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
end

# Construct a model with x = missing
model_kw = gdemo_kw(; x=missing)
c_kw = sample(model_kw, HMC(0.01, 5), 500)
julia> c_kw = sample(model_kw, HMC(0.01, 5), 500)
ERROR: DomainError with Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(NaN,NaN,NaN):
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/packages/Distributions/ji8PW/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:36 [inlined]
  [5] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:42 [inlined]
  [6] gdemo_kw(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, arg#225::DynamicPPL.TypeWrap{…}; x::Missing)
    @ Main ./REPL[2]:7
  [7] gdemo_kw
    @ ./REPL[2]:1 [inlined]
  [8] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
  [9] evaluate_threadsafe!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952
 [10] evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887
 [11] logdensity
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 [inlined]
 [12] Fix1
    @ ./operators.jl:1118 [inlined]
 [13] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [14] vector_mode_gradient!(result::DiffResults.MutableDiffResult{…}, f::Base.Fix1{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:96
 [15] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:37 [inlined]
 [16] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:35 [inlined]
 [17] logdensity_and_gradient
    @ ~/.julia/packages/LogDensityProblemsAD/rBlLq/ext/LogDensityProblemsADForwardDiffExt.jl:118 [inlined]
 [18] ∂logπ∂θ
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:159 [inlined]
 [19] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:38 [inlined]
 [20] step(lf::AdvancedHMC.Leapfrog{…}, h::AdvancedHMC.Hamiltonian{…}, z::AdvancedHMC.PhasePoint{…}, n_steps::Int64; fwd::Bool, full_trajectory::Val{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:229
 [21] step
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:199 [inlined]
 [22] sample_phasepoint
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:323 [inlined]
 [23] transition
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:262 [inlined]
 [24] transition(rng::Random.TaskLocalRNG, h::AdvancedHMC.Hamiltonian{…}, κ::AdvancedHMC.HMCKernel{…}, z::AdvancedHMC.PhasePoint{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/sampler.jl:59
 [25] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…}; nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:240
 [26] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…})
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/hmc.jl:226
 [27] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:176 [inlined]
 [28] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [29] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:9 [inlined]
 [30] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
 [31] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, kwargs::@Kwargs{})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:93
 [32] sample
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:83 [inlined]
 [33] #sample#4
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:263 [inlined]
 [34] sample
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:256 [inlined]
 [35] #sample#3
    @ ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:253 [inlined]
 [36] sample(model::DynamicPPL.Model{…}, alg::HMC{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/IyijE/src/mcmc/Inference.jl:247
 [37] top-level scope
    @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

  [0bf59076] AdvancedHMC v0.6.1
  [cbdf2221] AlgebraOfGraphics v0.6.19
  [c7e460c6] ArgParse v1.2.0
  [131c737c] ArviZ v0.10.5
  [4a6e88f0] ArviZPythonPlots v0.1.5
  [336ed68f] CSV v0.10.14
⌃ [13f3f980] CairoMakie v0.11.11
  [324d7699] CategoricalArrays v0.10.8
  [a93c6f00] DataFrames v1.6.1
  [1a297f60] FillArrays v1.11.0
  [663a7486] FreeTypeAbstraction v0.10.3
  [682c06a0] JSON v0.21.4
  [98e50ef6] JuliaFormatter v1.0.56
⌅ [ee78f7c6] Makie v0.20.10
  [7f7a1694] Optimization v3.25.1
  [b1d3bc72] Pathfinder v0.8.7
  [f27b6e38] Polynomials v4.0.9
  [438e738f] PyCall v1.96.4
  [37e2e3b7] ReverseDiff v1.15.3
  [295af30f] Revise v3.5.14
  [2913bbd2] StatsBase v0.34.3
  [f3b207a7] StatsPlots v0.15.7
  [fce5fe82] Turing v0.32.3
  [e88e6eb3] Zygote v0.6.70
daeh commented 3 weeks ago

Tried again with Turing v0.33.0; the same error occurs:

julia> c_kw = sample(model_kw, HMC(0.01, 5), 500)
ERROR: DomainError with Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}}(NaN,NaN,NaN):
Normal: the condition σ >= zero(σ) is not satisfied.
Stacktrace:
  [1] #371
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [2] check_args
    @ ~/.julia/packages/Distributions/ji8PW/src/utils.jl:89 [inlined]
  [3] #Normal#370
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:37 [inlined]
  [4] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:36 [inlined]
  [5] Normal
    @ ~/.julia/packages/Distributions/ji8PW/src/univariate/continuous/normal.jl:42 [inlined]
  [6] gdemo_kw(__model__::DynamicPPL.Model{…}, __varinfo__::DynamicPPL.ThreadSafeVarInfo{…}, __context__::DynamicPPL.SamplingContext{…}, arg#225::DynamicPPL.TypeWrap{…}; x::Missing)
    @ Main ./REPL[5]:7
  [7] gdemo_kw
    @ ./REPL[5]:1 [inlined]
  [8] _evaluate!!
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:963 [inlined]
  [9] evaluate_threadsafe!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:952
 [10] evaluate!!(model::DynamicPPL.Model{…}, varinfo::DynamicPPL.TypedVarInfo{…}, context::DynamicPPL.SamplingContext{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/model.jl:887
 [11] logdensity
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/logdensityfunction.jl:94 [inlined]
 [12] Fix1
    @ ./operators.jl:1118 [inlined]
 [13] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:24 [inlined]
 [14] vector_mode_gradient!(result::DiffResults.MutableDiffResult{…}, f::Base.Fix1{…}, x::Vector{…}, cfg::ForwardDiff.GradientConfig{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:96
 [15] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:37 [inlined]
 [16] gradient!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/gradient.jl:35 [inlined]
 [17] logdensity_and_gradient
    @ ~/.julia/packages/LogDensityProblemsAD/rBlLq/ext/LogDensityProblemsADForwardDiffExt.jl:118 [inlined]
 [18] ∂logπ∂θ
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/hmc.jl:159 [inlined]
 [19] ∂H∂θ
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/hamiltonian.jl:38 [inlined]
 [20] step(lf::AdvancedHMC.Leapfrog{…}, h::AdvancedHMC.Hamiltonian{…}, z::AdvancedHMC.PhasePoint{…}, n_steps::Int64; fwd::Bool, full_trajectory::Val{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:229
 [21] step
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/integrator.jl:199 [inlined]
 [22] sample_phasepoint
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:323 [inlined]
 [23] transition
    @ ~/.julia/packages/AdvancedHMC/AlvV4/src/trajectory.jl:262 [inlined]
 [24] transition(rng::Random.TaskLocalRNG, h::AdvancedHMC.Hamiltonian{…}, κ::AdvancedHMC.HMCKernel{…}, z::AdvancedHMC.PhasePoint{…})
    @ AdvancedHMC ~/.julia/packages/AdvancedHMC/AlvV4/src/sampler.jl:59
 [25] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…}; nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/iRdIB/src/mcmc/hmc.jl:240
 [26] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, state::Turing.Inference.HMCState{…})
    @ Turing.Inference ~/.julia/packages/Turing/iRdIB/src/mcmc/hmc.jl:226
 [27] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:176 [inlined]
 [28] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [29] macro expansion
    @ ~/.julia/packages/AbstractMCMC/YrmkI/src/logging.jl:9 [inlined]
 [30] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/YrmkI/src/sample.jl:120
 [31] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; chain_type::Type, resume_from::Nothing, initial_state::Nothing, kwargs::@Kwargs{})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:93
 [32] sample
    @ ~/.julia/packages/DynamicPPL/E4kDs/src/sampler.jl:83 [inlined]
 [33] #sample#4
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:263 [inlined]
 [34] sample
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:256 [inlined]
 [35] #sample#3
    @ ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:253 [inlined]
 [36] sample(model::DynamicPPL.Model{…}, alg::HMC{…}, N::Int64)
    @ Turing.Inference ~/.julia/packages/Turing/iRdIB/src/mcmc/Inference.jl:247
 [37] top-level scope
    @ REPL[7]:1
Some type information was truncated. Use `show(err)` to see complete types.
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 12 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

Status `~/coding/-GitRepos/knobe-counterfactuals/Project.toml`
  [98e50ef6] JuliaFormatter v1.0.56
  [295af30f] Revise v3.5.14
  [fce5fe82] Turing v0.33.0
torfjelde commented 2 weeks ago

To be honest, I'm not sure whether this is intended, but I agree that keyword arguments should be treated the same way as positional arguments.

One way you can easily check what's considered random and what's considered "observed" is to just sample from the model:

# Draw one sample from the model and collect it into an ordered dictionary.
# NOTE(review): presumably keyed by variable name, with only the variables the
# model treats as random (not observed) present — inspect the keys to see
# which is which.
rand(Turing.OrderedDict, model)
torfjelde commented 2 weeks ago

By the way, we now generally recommend using `condition` (the `|` operator) instead of passing observations in as model arguments or keyword arguments. That is, write your model as

# Recommended pattern: observations are deliberately NOT a model argument.
# Every `x[i]` is written as a random variable, and data are attached
# afterwards with `condition` (the `|` operator). Taking no `x` argument also
# makes the call `gdemo()` below valid; a required positional `x` would make
# it a `MethodError`, and would be shadowed by the local `x` anyway.
# `T` is the element type of the local `x` container.
@model function gdemo(::Type{T}=Float64) where {T}
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    # Container for the two `x[i]` variables.
    x = Vector{T}(undef, 2)
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s²))
    end
end

# Example observations for `x[1]` and `x[2]` (replace with real data).
x_data = [1.5, 2.0]

model = gdemo()
# Condition on the data: `x` is now treated as observed rather than sampled.
model_cond = model | (x = x_data,)

Going forward this will be the recommended way of doing things.

daeh commented 2 weeks ago

Thanks for the quick fix! Yes, I'll start using the condition syntax from here on out. I posted the issue primarily because the kwarg behavior was very unexpected (took me a while to figure out what the issue was), and I imagined it could trip up other Turing newbies too. Thanks!

torfjelde commented 1 week ago

Thank you for bringing up the issue :) I was not aware of this bug until you reported it.