cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Unable to sample posterior of the MC example from JuliaCon 2021 #347

Open AxLamelas opened 2 years ago

AxLamelas commented 2 years ago

Hello, I'm trying to reproduce the Markov chain example from JuliaCon 2021, but I cannot get the sampling to work. I saw issue #293 and tried the code from there, but it still fails in the same way. Here is the code I'm using:

using Soss
using Base.Iterators
using SampleChainsDynamicHMC

mc_init = @model begin
    x ~ Normal(0.0, 1.0)
    return (x=x,)
end

mc_step = @model p, s begin
    x ~ Normal(s.x + p.Δμ, p.σ)
    return (x = x,)
end

m = @model begin
    Δμ ~ Normal()
    σ ~ HalfNormal()
    p = (; Δμ, σ)
    mc ~  Chain(mc_init()) do s mc_step(s=s,p=p) end 
    return mc
end

obs = take(rand(m()),400) |> collect
post = m() | (mc = obs,)
chain = sample(post,dynamichmc())

And here is the error:

ERROR: MethodError: no method matching logdensity_def(::Soss.ConditionalModel{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}, ::NamedTuple{(:x,), Tuple{Float64}})
Closest candidates are:
  logdensity_def(::Sorted, ::Any) at C:\Users\AxLamelas\.julia\packages\MeasureTheory\GI8wX\src\transforms\ordered.jl:75
  logdensity_def(::Laplace{()}, ::Any) at C:\Users\AxLamelas\.julia\packages\MeasureTheory\GI8wX\src\parameterized\laplace.jl:24
  logdensity_def(::Laplace{(:μ,)}, ::Any) at C:\Users\AxLamelas\.julia\packages\MeasureBase\Tfp1d\src\proxies.jl:17
  ...
Stacktrace:
  [1] logdensity_def
    @ C:\Users\AxLamelas\.julia\packages\MeasureTheory\GI8wX\src\combinators\chain.jl:28 [inlined]
  [2] unsafe_logdensityof(μ::Chain{GeneralizedGenerated.Closure{function = (p, s;) -> begin
    (Main).mc_step(s = s, p = p)
end, Tuple{NamedTuple{(:Δμ, :σ), Tuple{ForwardDiff.Dual{ForwardDiff.Tag{LogDensityProblems.var"#46#47"{LogDensityProblems.TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:σ, :Δμ), Tuple{TransformVariables.ShiftedExp{true, Int64}, TransformVariables.Identity}}}, Soss.var"#ℓ#154"{Soss.ConditionalModel{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(:mc,), Tuple{Vector{Any}}}}}}}, Float64}, Float64, 2}, ForwardDiff.Dual{ForwardDiff.Tag{LogDensityProblems.var"#46#47"{LogDensityProblems.TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:σ, :Δμ), Tuple{TransformVariables.ShiftedExp{true, Int64}, TransformVariables.Identity}}}, Soss.var"#ℓ#154"{Soss.ConditionalModel{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(:mc,), Tuple{Vector{Any}}}}}}}, Float64}, Float64, 2}}}}}, Soss.ConditionalModel{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, x::Vector{Any})
    @ MeasureBase C:\Users\AxLamelas\.julia\packages\MeasureBase\Tfp1d\src\density.jl:139

I get this error both with the Windows version of Julia and under WSL2.

Package versions:

Status `C:\Users\AxLamelas\Documents\Personal\Credit simulations\Julia\Manifest.toml`
[398f06c4] AbstractLattices v0.2.1 [1520ce14] AbstractTrees v0.4.2 [7d9f7c33] Accessors v0.1.17 [79e6a3ab] Adapt v3.3.3 [dce04be8] ArgCheck v2.3.0 [4fba245c] ArrayInterface v6.0.17 [30b0a656] ArrayInterfaceCore v0.1.14 [b0d46f97] ArrayInterfaceStaticArrays v0.1.2 [dd5226c6] ArrayInterfaceStaticArraysCore v0.1.0 [4c555306] ArrayLayouts v0.8.9 [65a8f2f4] ArraysOfArrays v0.5.10 [15f4f7f2] AutoHashEquals v0.2.0 [198e06fe] BangBang v0.3.36 [9718e550] Baselet v0.1.1 [e2ed5e7c] Bijections v0.1.4 [49dc2e85] Calculus v0.5.1 [d360d2e6] ChainRulesCore v1.15.2 [9e997f8a] ChangesOfVariables v0.1.4 [861a8166] Combinatorics v1.0.2 [bbf7d656] CommonSubexpressions v0.3.0 [34da2185] Compat v3.45.0 [a33af91c] CompositionsBase v0.1.1 [2569d6c7] ConcreteStructs v0.2.3 [187b0558] ConstructionBase v1.4.0 [9a962f9c] DataAPI v1.10.0 [864edb3b] DataStructures v0.18.13 [e2d170a0] DataValueInterfaces v1.0.0 [244e2a9f] DefineSingletons v0.1.2 [b429d917] DensityInterface v0.4.0 [163ba53b] DiffResults v1.0.3 [b552c78f] DiffRules v1.11.0 [31c24e10] Distributions v0.25.65 [ffbed154] DocStringExtensions v0.8.6 [fa6b7ba4] DualNumbers v0.6.8 [bbc10e6e] DynamicHMC v3.1.1 [6c76993d] DynamicIterators v0.4.2 [7c1d4256] DynamicPolynomials v0.4.5 [fdbdab4c] ElasticArrays v1.2.10 [e2ba6199] ExprTools v0.1.8 [1a297f60] FillArrays v0.13.2 [6a86dc24] FiniteDiff v2.13.1 [b67e1e5a] FlexLinearAlgebra v0.1.0 [f6369f11] ForwardDiff v0.10.30 [46192b85] GPUArraysCore v0.1.0 [6b9d7cbe] GeneralizedGenerated v0.3.3 [34004b35] HypergeometricFunctions v0.3.11 [615f187c] IfElse v0.1.1 [e1ba4f0e] Infinities v0.1.4 [22cec73e] InitialValues v0.3.1 [842dd82b] InlineStrings v1.1.4 [18e54dd8] IntegerMathUtils v0.1.0 [d8418881] Intervals v1.8.0 [3587e190] InverseFunctions v0.1.7 [92d709cd] IrrationalConstants v0.1.1 [c8e1da08] IterTools v1.4.0 [82899510] IteratorInterfaceExtensions v1.0.0 [692b3bcd] JLLWrappers v1.4.1
[b14d175d] JuliaVariables v0.2.4 [4d827475] KeywordCalls v0.2.4 [2ee39098] LabelledArrays v1.11.1 [5078a376] LazyArrays v0.22.11 [9c8b4983] LightXML v0.9.0 [d3d80556] LineSearches v7.1.1 [9b3f67b0] LinearAlgebraX v0.1.9 [6fdf6af0] LogDensityProblems v0.11.4 [2ab3a3ac] LogExpFunctions v0.3.15 [aa2f6b4e] LogarithmicNumbers v1.2.0 [6e857e4b] MCMCDiagnostics v0.3.0 [d8e11817] MLStyle v0.4.13 [1914dd2f] MacroTools v0.5.9 [dbb5928d] MappedArrays v0.4.1 [a3b82374] MatrixFactorizations v0.9.1 [fa1605e6] MeasureBase v0.9.4 [eadaa1a4] MeasureTheory v0.16.2 [eff96d63] Measurements v2.7.2 [e9d8d322] Metatheory v1.3.4 [128add7d] MicroCollections v0.1.2 [e1d29d7a] Missings v1.0.2 [78c3b35d] Mocking v0.7.3 [7475f97c] Mods v1.3.2 [3b2b4ff1] Multisets v0.4.4 [102ac46a] MultivariatePolynomials v0.4.6 [d8a4904e] MutableArithmetics v1.0.4 [d41bc354] NLSolversBase v7.8.2 [77ba4419] NaNMath v0.3.7 [71a1bf82] NameResolution v0.1.5 [d9ec5142] NamedTupleTools v0.14.0 [a734d2a7] NestedTuples v0.3.10 [429524aa] Optim v1.7.0 [bac558e1] OrderedCollections v1.4.1 [90014a1f] PDMats v0.11.16 [d96e819e] Parameters v0.12.3 [69de0a69] Parsers v2.3.2 [2ae35dd2] Permutations v0.4.14 [f27b6e38] Polynomials v2.0.25 [85a6dd25] PositiveFactorizations v0.2.4 [d236fae5] PreallocationTools v0.4.0 [21216c6a] Preferences v1.3.0 [8162dcfd] PrettyPrint v0.2.0 [54e16d92] PrettyPrinting v0.4.0 [27ebfcd6] Primes v0.5.3 [92933f4c] ProgressMeter v1.7.2 [1fd47b50] QuadGK v2.4.2 [3cdcf5f2] RecipesBase v1.2.1 [731186ca] RecursiveArrayTools v2.31.2 [189a3867] Reexport v1.2.2 [42d2dcc6] Referenceables v0.1.2 [ae029012] Requires v1.3.0 [286e9d63] RingLists v0.2.7 [79098fc4] Rmath v0.7.0 [7e49a35a] RuntimeGeneratedFunctions v0.5.3 [754583d1] SampleChains v0.5.1 [6d9fd711] SampleChainsDynamicHMC v0.3.5 [6c6a2e73] Scratch v1.1.1 [efcf1570] Setfield v0.8.2 [55797a34] SimpleGraphs v0.7.18 [ec83eff0] SimplePartitions v0.3.0 [cc47b68c] SimplePolynomials v0.2.9 [b2aef97b] SimplePosets v0.1.5 [a6525b86] SimpleRandom v0.3.1 
[a2af1166] SortingAlgorithms v1.0.1 [8ce77f84] Soss v0.21.0 [276daf66] SpecialFunctions v2.1.7 [171d559e] SplittablesBase v0.1.14 [aedffcd0] Static v0.6.6 [90137ffa] StaticArrays v1.5.1 [1e83bf80] StaticArraysCore v1.0.1 [82ae8749] StatsAPI v1.4.0 [2913bbd2] StatsBase v0.33.19 [4c63d2b9] StatsFuns v1.0.1 [09ab397b] StructArrays v0.6.11 [d1185830] SymbolicUtils v0.19.11 [3783bdb8] TableTraits v1.0.1 [bd369af6] Tables v1.7.0 [8ea1fca8] TermInterface v0.2.3 [ac1d9e8a] ThreadsX v0.1.10 [f269a46b] TimeZones v1.9.0 [a759f4b9] TimerOutputs v0.5.20 [2c80a279] Trajectories v0.2.2 [28d57a85] Transducers v0.4.73 [84d833dd] TransformVariables v0.6.2 [410a4b4d] Tricks v0.1.6 [615932cf] TupleVectors v0.1.5 [3a884ed6] UnPack v1.0.2 [c4a57d5a] UnsafeArrays v1.0.4 [700de1a5] ZygoteRules v0.2.2 [94ce4f54] Libiconv_jll v1.16.1+1 [efe28fd5] OpenSpecFun_jll v0.5.5+0 [f50d1b31] Rmath_jll v0.3.0+0 [02c8fc9c] XML2_jll v2.9.14+0 [0dad84c5] ArgTools [56f22d72] Artifacts [2a0f44e3] Base64 [ade2ca70] Dates [8bb1440f] DelimitedFiles [8ba89e20] Distributed [f43a241f] Downloads [7b1f6079] FileWatching [9fa8497b] Future [b77e0a4c] InteractiveUtils [4af54fe1] LazyArtifacts [b27032c2] LibCURL [76f85450] LibGit2 [8f399da3] Libdl [37e2e46d] LinearAlgebra [56ddb016] Logging [d6f4376e] Markdown [a63ad114] Mmap [ca575930] NetworkOptions [44cfe95a] Pkg [de0858da] Printf [3fa0cd96] REPL [9a3f8284] Random [ea8e919c] SHA [9e88b42a] Serialization [1a1011a3] SharedArrays [6462fe0b] Sockets [2f01184e] SparseArrays [10745b16] Statistics [4607b0f0] SuiteSparse [fa267f1f] TOML [a4e569a6] Tar [8dfed614] Test [cf7118a7] UUIDs [4ec0a83e] Unicode [e66e0078] CompilerSupportLibraries_jll [deac9b47] LibCURL_jll [29816b5a] LibSSH2_jll [c8ffd9c3] MbedTLS_jll [14a3606d] MozillaCACerts_jll [4536629a] OpenBLAS_jll [05823500] OpenLibm_jll [83775a58] Zlib_jll [8e850b90] libblastrampoline_jll [8e850ede] nghttp2_jll [3f19e933] p7zip_jll

cscherrer commented 2 years ago

Hi @AxLamelas, thanks for letting me know about this. I'm looking into it now.

cscherrer commented 2 years ago

The logdensity_def problem is easily addressed by adding a primitive for it, but I'm still seeing an issue after doing that. The remaining problem can be reduced to:

using Soss

mc_init = @model begin
    x ~ Normal(0.0, 1.0)
    return (x=x,)
end

mc_step = @model s begin
    x ~ Normal(s.x + 0.1, 0.2)
    return (x = x,)
end

d = Chain(mc_init()) do s mc_step(s=s) end
x = Iterators.take(rand(d), 10) |> collect
logdensity_def(d,x)

This gives the result

ERROR: MethodError: First argument to `convert` must be a Type, got (x = Float64,)
Stacktrace:
 [1] macro expansion
   @ ~/git/Soss.jl/src/primitives/logdensity.jl:92 [inlined]
 [2] _logdensity_def(M::Type{Soss.GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _m::Model{NamedTuple{(:p, :s)}, Soss.GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, Soss.GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _args::NamedTuple{(:s, :p), Tuple{NamedTuple{(:x,), Tuple{Float64}}, NamedTuple{(:Δμ, :σ), Tuple{Float64, Float64}}}}, _data::NamedTuple{(), Tuple{}}, _pars::NamedTuple{(:x,), Tuple{Float64}})
   @ Soss ~/git/Soss.jl/src/primitives/logdensity.jl:92
 [3] logdensity_def(c::Soss.ConditionalModel{NamedTuple{(:p, :s)}, Soss.GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, Soss.GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(:s, :p), Tuple{NamedTuple{(:x,), Tuple{Float64}}, NamedTuple{(:Δμ, :σ), Tuple{Float64, Float64}}}}, NamedTuple{(), Tuple{}}}, x::NamedTuple{(:x,), Tuple{Float64}})
   @ Soss ~/git/Soss.jl/src/primitives/logdensity.jl:51
 [4] logdensity_def(mc::Chain{var"#19#20", Soss.ConditionalModel{NamedTuple{()}, Soss.GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, Soss.GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, x::Vector{Any})
   @ MeasureTheory ~/.julia/packages/MeasureTheory/MeOXc/src/combinators/chain.jl:28
 [5] top-level scope
   @ REPL[89]:1
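
For reference, the quantity that `logdensity_def(d, x)` is meant to compute for this reduced chain can be sketched by hand in plain Julia. This is only a rough check, not Soss or MeasureTheory code: `normal_logpdf` and `chain_logpdf` are hypothetical helper names, the states are plain numbers rather than named tuples `(x = ...,)`, and the density is taken with respect to Lebesgue measure, so a library result may differ by a constant depending on the base measure it uses.

```julia
# Log-density of Normal(μ, σ) with respect to Lebesgue measure.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Hand-written log-density for the reduced chain (hypothetical helper):
# xs[1] ~ Normal(0, 1) and xs[t] ~ Normal(xs[t-1] + 0.1, 0.2) for t > 1.
function chain_logpdf(xs::AbstractVector{<:Real})
    ℓ = normal_logpdf(xs[1], 0.0, 1.0)          # initial state
    for t in 2:length(xs)
        # each step is conditioned on the previous state
        ℓ += normal_logpdf(xs[t], xs[t-1] + 0.1, 0.2)
    end
    return ℓ
end

chain_logpdf([0.0, 0.1, 0.2])
```

Once the call above works, comparing it against a hand computation like this on a short trajectory would be a quick sanity check.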

cscherrer commented 2 years ago

I had hoped this would just be a matter of adding a logdensity_def method. After adding one in https://github.com/cscherrer/Soss.jl/pull/350, I still get the above error:

ERROR: MethodError: First argument to `convert` must be a Type, got (x = Float64,)

We can also look at the simpler example:

julia> mc_step = @model x begin
               y ~ Normal(x + 0.1, 0.2)
               return y
       end;

julia> d = Chain(Normal()) do s mc_step(s) end;

julia> x = Iterators.take(rand(d), 10) |> collect;

julia> logdensity_def(d,x)
ERROR: MethodError: no method matching keys(::Type{Float64})
Closest candidates are:
  keys(::Union{Tables.AbstractColumns, Tables.AbstractRow}) at ~/.julia/packages/Tables/PxO1m/src/Tables.jl:184
  keys(::Missings.EachReplaceMissing) at ~/.julia/packages/Missings/r1STI/src/Missings.jl:94
  keys(::DataStructures.Trie) at ~/.julia/packages/DataStructures/59MD0/src/trie.jl:82
  ...
Stacktrace:
 [1] loadvals(argstype::Type, datatype::Type, parstype::Type)
   @ Soss ~/git/Soss.jl/src/core/utils.jl:217
 [2] #s112#76
   @ ~/git/Soss.jl/src/primitives/logdensity.jl:92 [inlined]
 [3] var"#s112#76"(::Any, M::Any, _m::Any, _args::Any, _data::Any, _pars::Any)
   @ Soss ./none:0
 [4] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
   @ Core ./boot.jl:582
 [5] logdensity_def(c::Soss.ConditionalModel{NamedTuple{(:x,)}, Soss.GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, Soss.GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(:x,), Tuple{Float64}}, NamedTuple{(), Tuple{}}}, x::Float64)
   @ Soss ~/git/Soss.jl/src/primitives/logdensity.jl:50
 [6] logdensity_def(mc::Chain{var"#11#12", Normal{(), Tuple{}}}, x::Vector{Any})
   @ MeasureTheory ~/.julia/packages/MeasureTheory/Mrznp/src/combinators/chain.jl:28
 [7] top-level scope
   @ REPL[24]:1

So it looks like there's a place where types are mistakenly being treated as values. That kind of problem usually points to a metaprogramming issue, maybe in a generated function. But then it's also strange that other tests pass, including those for nested models. So it might have to do with the implementation of Chain, or possibly its interaction with the generated functions.
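
To illustrate the kind of type/value mix-up I mean, here is a minimal sketch in plain Julia. It is not the Soss code; `generator_view` is a made-up name for illustration. Inside the generator body of a `@generated` function, each argument name is bound to the argument's type, so a value-level call such as `keys` fails in the same way as in the trace above if it is handed that type.

```julia
# The generator body runs at compile time and sees argument *types*;
# only the returned expression sees the run-time values.
@generated function generator_view(x)
    seen = x                  # a type here, e.g. Float64
    return :(($seen, x))      # pair that type with the run-time value
end

generator_view(1.5)           # (Float64, 1.5)

# A value-level function handed a type fails like the trace above.
err = try
    keys(Float64)
catch e
    e
end
err isa MethodError           # true

# Type-level queries need type-level functions instead.
fieldnames(typeof((x = 1.0,)))   # (:x,)
```

So if some code path ends up passing a scalar type such as `Float64` where a named-tuple type is expected, this is roughly the failure I would expect to see.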

The errors I'm getting are kind of tricky to debug. I think a next step might be for me to implement this example in Tilde and see if I hit the same issue.