bat / BAT.jl

A Bayesian Analysis Toolkit in Julia
Other
198 stars 30 forks source link

Error with truncated distribution as prior #450

Closed gipert closed 5 days ago

gipert commented 1 week ago

BAT crashes deterministically when a truncated distribution is used as a prior. If I set the following in `docs/src/bat_tutorial.jl`:

prior = distprod(
    a = [Weibull(1.1, 5000), Weibull(1.1, 5000)],
    mu = [-2.0..0.0, 1.0..3.0],
    sigma = truncated(Normal(0, 2), lower=0)
)

I get:

> julia docs/src/bat_tutorial.jl
[ Info: Setting new default BAT context BATContext{Float64}(Random123.Philox4x{UInt64, 10}(0x8bdbac22c8306763, 0x54639feaadacd6cf, 0x6765c3d759a52d0e, 0x4c6d85b755dcf9f0, 0x1e067ade302b087d, 0xf3403de12719c1b2, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0x0000000000000000, 0), HeterogeneousComputing.CPUnit(), BAT._NoADSelected())
[ Info: MCMCChainPoolInit: trying to generate 4 viable MCMC chain(s).
┌ Debug: Generating dummy MCMC chain to determine chain, output and tuner types.
└ @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/chain_pool_init.jl:80
ERROR: LoadError: ArgumentError: Can't derive numeric type for type Nothing
Stacktrace:
  [1] realnumtype(::Type{Nothing})
    @ ValueShapes ~/.julia/packages/ValueShapes/rT1Zi/src/value_shape.jl:22
  [2] map
    @ ./tuple.jl:293 [inlined]
  [3] map
    @ ./tuple.jl:294 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/ValueShapes/rT1Zi/src/value_shape.jl:31 [inlined]
  [5] realnumtype(::Type{Tuple{Float64, Float64, Float64, Nothing}})
    @ ValueShapes ~/.julia/packages/ValueShapes/rT1Zi/src/value_shape.jl:31
  [6] _dist_params_numtype(d::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:307
  [7] _eval_dist_trafo_func
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:373 [inlined]
  [8] apply_dist_trafo(trg_d::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, ::BAT.StandardUvUniform{Float64}, src_v::Float64)
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:395
  [9] apply_dist_trafo(trg_d::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, src_d::BAT.StandardUvNormal{Float64}, src_v::Float64)
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:296
 [10] WithForwardDiff
    @ ~/.julia/packages/ForwardDiffPullbacks/s8kVo/src/with_forwarddiff.jl:22 [inlined]
 [11] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
 [12] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
 [13] getindex
    @ ./broadcast.jl:636 [inlined]
 [14] copy
    @ ./broadcast.jl:942 [inlined]
 [15] materialize
    @ ./broadcast.jl:903 [inlined]
 [16] _product_dist_trafo_impl(trg_ds::FillArrays.Fill{Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}, src_ds::BAT.StandardUvNormal{Float64}, src_v::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:540
 [17] apply_dist_trafo
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:565 [inlined]
 [18] _stdmv_to_flat_ntdistelem(td::Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}, src_d::BAT.StandardMvNormal{Float64}, src_v::Vector{Float64}, src_acc::ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:615
 [19] #119
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:625 [inlined]
 [20] map
    @ ./tuple.jl:319 [inlined]
 [21] map
    @ ./tuple.jl:322 [inlined]
 [22] apply_dist_trafo(trg_d::ValueShapes.UnshapedNTD{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}, src_d::BAT.StandardMvNormal{Float64}, src_v::Vector{Float64})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:625
 [23] apply_dist_trafo(trg_d::ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, src_d::BAT.StandardMvNormal{Float64}, src_v::Vector{Float64})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:630
 [24] DistributionTransform
    @ ~/.julia/packages/BAT/nxtXP/src/transforms/distribution_transform.jl:156 [inlined]
 [25] macro expansion
    @ ~/.julia/packages/FunctionChains/piKSk/src/function_chain.jl:0 [inlined]
 [26] FunctionChain
    @ ~/.julia/packages/FunctionChains/piKSk/src/function_chain.jl:161 [inlined]
 [27] logdensityof
    @ ~/.julia/packages/DensityInterface/MCyV6/src/interface.jl:256 [inlined]
 [28] logdensityof(density::PosteriorMeasure{DensityInterface.LogFuncDensity{FunctionChains.FunctionChain{Tuple{BAT.DistributionTransform{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, BAT.StandardMvNormal{Float64}}, var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}}}, BAT.BATDistMeasure{BAT.StandardMvNormal{Float64}}}, v::Vector{Float64})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/measures/posterior_measure.jl:59
 [29] BAT.MHIterator(algorithm::MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, target::PosteriorMeasure{DensityInterface.LogFuncDensity{FunctionChains.FunctionChain{Tuple{BAT.DistributionTransform{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, BAT.StandardMvNormal{Float64}}, var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}}}, BAT.BATDistMeasure{BAT.StandardMvNormal{Float64}}}, info::BAT.MCMCIteratorInfo, x_init::Vector{Float64}, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/mh/mh_sampler.jl:92
 [30] MCMCIterator
    @ ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/mh/mh_sampler.jl:135 [inlined]
 [31] mcmc_init!(algorithm::MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, density::PosteriorMeasure{DensityInterface.LogFuncDensity{FunctionChains.FunctionChain{Tuple{BAT.DistributionTransform{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}, BAT.StandardMvNormal{Float64}}, var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}}}, BAT.BATDistMeasure{BAT.StandardMvNormal{Float64}}}, nchains::Int64, init_alg::MCMCChainPoolInit, tuning_alg::AdaptiveMHTuning, nonzero_weights::Bool, callback::Function, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/chain_pool_init.jl:85
 [32] bat_sample_impl(m::PosteriorMeasure{DensityInterface.LogFuncDensity{var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}, BAT.BATDistMeasure{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}}, algorithm::MCMCSampling{MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, PriorToGaussian, MCMCChainPoolInit, MCMCMultiCycleBurnin, BrooksGelmanConvergence, typeof(BAT.nop_func)}, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/samplers/mcmc/mcmc_sample.jl:46
 [33] bat_sample(target::PosteriorMeasure{DensityInterface.LogFuncDensity{var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}, BAT.BATDistMeasure{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}}, algorithm::MCMCSampling{MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, PriorToGaussian, MCMCChainPoolInit, MCMCMultiCycleBurnin, BrooksGelmanConvergence, typeof(BAT.nop_func)}, context::BATContext{Float64, Random123.Philox4x{UInt64, 10}, HeterogeneousComputing.CPUnit, BAT._NoADSelected})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/algotypes/sampling_algorithm.jl:56
 [34] bat_sample(target::PosteriorMeasure{DensityInterface.LogFuncDensity{var"#3#4"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Int64}, typeof(fit_function)}}, BAT.BATDistMeasure{ValueShapes.NamedTupleDist{(:a, :mu, :sigma), Tuple{Product{Continuous, Weibull{Float64}, Vector{Weibull{Float64}}}, Product{Continuous, Uniform{Float64}, Vector{Uniform{Float64}}}, Truncated{Normal{Float64}, Continuous, Float64, Float64, Nothing}}, Tuple{ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ArrayShape{Real, 1}}, ValueShapes.ValueAccessor{ValueShapes.ScalarShape{Real}}}, NamedTuple}}}, algorithm::MCMCSampling{MetropolisHastings{TDist{Float64}, RepetitionWeighting{Int64}, AdaptiveMHTuning}, PriorToGaussian, MCMCChainPoolInit, MCMCMultiCycleBurnin, BrooksGelmanConvergence, typeof(BAT.nop_func)})
    @ BAT ~/.julia/packages/BAT/nxtXP/src/algotypes/sampling_algorithm.jl:67
 [35] top-level scope
    @ ~/sw/src/BAT.jl/docs/src/bat_tutorial.jl:84
in expression starting at /home/gipert/sw/src/BAT.jl/docs/src/bat_tutorial.jl:84

Is there any workaround?

oschulz commented 1 week ago

That's strange — we have used truncated priors a lot. I'll fix this.

oschulz commented 5 days ago

Ah, with `d = truncated(Normal(0, 2), lower=0)`, `d.upper` is `nothing` (of type `Nothing`), which `ValueShapes.realnumtype` currently can't handle.

For now, just use `truncated(Normal(0, 2), 0, Inf)` instead; I'll fix this in ValueShapes.

oschulz commented 5 days ago

This will be fixed by https://github.com/oschulz/ValueShapes.jl/pull/78

oschulz commented 5 days ago

Fixed in ValueShapes v0.11.3.

gipert commented 5 days ago

Thanks Oli!