cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Nested traces #227

Closed cscherrer closed 3 years ago

cscherrer commented 3 years ago

Say we have a setup like this:

using Soss

μdist = @model begin
    a ~ Normal()
    b ~ Normal()
    return a/b
end

σdist = @model begin
    x ~ Normal()
    return x^2
end

m = @model begin
    μ ~ μdist
    σ ~ σdist
    x ~ Normal(μ,σ) |> iid(2)
    return x
end

Then currently (in the cs-conditional branch) we can do

julia> rand(m())
2-element Array{Float64,1}:
 -1.0004969741379428
 -0.3560707162245288

and

julia> sample(m())
(value = [-1.0975149021645443, -1.0606655372891374], trace = (σ = 0.6664363771789825, μ = -0.8943549950564752, x = [-1.0975149021645443, -1.0606655372891374]))

But this isn't quite enough. In many cases we need access to the internal state of μdist or σdist. For forward sampling, it's probably enough to change, e.g., μ = -0.8943549950564752 to μ = (value = -0.8943549950564752, trace = ...). Beyond that, though, we'll often need to "modify" the result (in the Accessors.jl sense) instead.

So my current thought is that we should carry a context in the form of an Accessors.ComposedOptic. This can be passed around and used like a kind of stack pointer. I still need to work through some details of this, so I'm opening this issue to track any ideas.
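As a rough illustration of the "stack pointer" idea, here is a minimal sketch using Accessors.jl optics on a plain NamedTuple. It is not Soss code: the trace values are made up, the names ctx_μ and lens_a are hypothetical, and the (value, trace) layout is only the shape proposed above.

```julia
using Accessors

# A nested trace in the proposed (value = ..., trace = ...) shape.
# The numbers are made up for illustration.
tr = (μ = (value = 2.0, trace = (a = 1.0, b = 0.5)),
      σ = (value = 0.25, trace = (x = 0.5,)))

# The "context" while evaluating inside μdist: an optic pointing at its trace.
ctx_μ = @optic _.μ.trace

# Descending one more level composes optics. With ∘ the right-hand optic
# is applied first, so this reads: go to μ.trace, then to a.
lens_a = (@optic _.a) ∘ ctx_μ

a_val = lens_a(tr)            # read through the optic
tr2 = set(tr, lens_a, 3.0)    # "modify": returns an updated copy, tr is unchanged
```

Because the context is an ordinary value, it can be passed down to a submodel and extended there, which is what makes it behave like a pointer into the enclosing trace.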

cscherrer commented 3 years ago

After a bit of fiddling, it now outputs the following (formatting edited to show the hierarchy):

julia> sample(m())
( value = 
    [ 2.928731866308371
    , 2.752746485036528]
, trace = 
    ( σ = 
        ( value = 0.05806831281496894
        , trace = 
            ( x = 0.2409736766017586,)
        )
    , μ = 
        ( value = 2.838377243271297
        , trace = 
            ( b = 0.1139546406288632
            , a = 0.32344625872612404)
        )
    , x = 
        [ 2.928731866308371
        , 2.752746485036528]
    )
)

The biggest question is whether this is the right way of nesting the hierarchy.
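For reference, a small sketch of what reading from this nesting looks like. It copies the result above into a plain NamedTuple literal (no Soss needed) and checks that each submodel's value can be recovered from its own trace. The variable names are hypothetical.

```julia
# The nested result shown above, copied as a plain NamedTuple literal.
s = (value = [2.928731866308371, 2.752746485036528],
     trace = (σ = (value = 0.05806831281496894,
                   trace = (x = 0.2409736766017586,)),
              μ = (value = 2.838377243271297,
                   trace = (b = 0.1139546406288632, a = 0.32344625872612404)),
              x = [2.928731866308371, 2.752746485036528]))

# Value returned by a submodel, and the internal draws that produced it.
μ_value = s.trace.μ.value
μ_inner = s.trace.μ.trace

# In the first setup, μdist returns a/b and σdist returns x^2,
# so each value should match what its trace implies.
μ_ok = μ_value ≈ μ_inner.a / μ_inner.b
σ_ok = s.trace.σ.value ≈ s.trace.σ.trace.x^2
```

With this layout, the outer name x and the x inside σdist do not collide, since the latter lives at s.trace.σ.trace.x.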

cscherrer commented 3 years ago

It's working!!

Say you have a model like

m = @model begin
    μ ~ μdist()
    σ ~ σdist()
    x ~ Normal(μ,σ) |> iid(10)
    return x
end

where for some contrived reason you want μdist and σdist to look like

μdist = @model begin
    s ~ Gamma()
    z ~ Normal()
    return sqrt(s)*z
end

σdist = @model begin
    x ~ Normal()
    return abs(x)
end

Then you can do

julia> rand(m())
10-element Array{Float64,1}:
 0.8987755216091674
 1.5404884435009312
 1.49674830638398
 0.836414450378643
 1.9763714945277209
 1.335772718161111
 0.9799369872413912
 1.3145970762617967
 1.7797988916041363
 1.7669977202709188

and

julia> sample(m())
(value = [0.16270957974862998, -0.39949227827205963, -0.366509595152026, 0.510598313387259, -0.188267753941064, -0.029582640615416114, 0.40188976052333547, -0.05958940382682634, -0.2242093224518913, -0.013685083779422232], trace = (x = [0.16270957974862998, -0.39949227827205963, -0.366509595152026, 0.510598313387259, -0.188267753941064, -0.029582640615416114, 0.40188976052333547, -0.05958940382682634, -0.2242093224518913, -0.013685083779422232], μ = (value = -0.03670340697988196, trace = (s = 0.05228547528055071, z = -0.16051508358592426)), σ = (value = 0.2814602275667513, trace = (x = 0.2814602275667513,))))

And HMC now works too!

julia> dynamicHMC(m() | (;x))
1000-element Array{NamedTuple{(:σ, :μ),Tuple{NamedTuple{(:x,),Tuple{Float64}},NamedTuple{(:z, :s),Tuple{Float64,Float64}}}},1}:
 (σ = (x = -2.7509614123288353,), μ = (z = 1.2507649275611596, s = 0.03573972089076121))
 (σ = (x = -2.5687359737754343,), μ = (z = 0.9475268575919211, s = 0.04227077853218441))
 (σ = (x = -1.3861789998350145,), μ = (z = -0.8038506706517361, s = 0.09821639861100565))
 (σ = (x = -1.3086421505430927,), μ = (z = -0.7710964208992048, s = 0.25829860363376733))
 (σ = (x = -1.305863990283008,), μ = (z = -0.7665432602037672, s = 0.3275944391695731))
 (σ = (x = -0.5641860250119911,), μ = (z = -0.7207446140773045, s = 0.380683944856869))
 (σ = (x = -0.658716030360198,), μ = (z = -0.5087355557531692, s = 0.17932308015909262))
 (σ = (x = -0.7664156061468581,), μ = (z = 0.5248257097598521, s = 0.13849062959075062))
 (σ = (x = -1.2727839678529427,), μ = (z = -0.3236480229772443, s = 0.2255534288204248))
 (σ = (x = -0.7155322490922567,), μ = (z = -0.6003604485802628, s = 1.101525448375478))
 (σ = (x = -0.7155322490922567,), μ = (z = -0.6003604485802628, s = 1.101525448375478))
 (σ = (x = -1.2472845602960045,), μ = (z = -0.2629837817403453, s = 0.0786623919543633))
 (σ = (x = -1.0041397554418758,), μ = (z = -0.4988484034541389, s = 2.264916433115464))
 (σ = (x = -0.8221055602791598,), μ = (z = -0.038824977188513354, s = 1.831235456990119))
 (σ = (x = -0.8146719070081059,), μ = (z = 0.848170599016523, s = 0.24827259063819362))
 (σ = (x = -0.9251644989435468,), μ = (z = 0.05331572097035464, s = 0.49484614636277857))
 (σ = (x = -1.1427576200850518,), μ = (z = -0.9145305802773287, s = 0.014539699386405482))
 (σ = (x = -1.5098415887084335,), μ = (z = 2.7104443261756868, s = 0.11762893415392027))
 (σ = (x = -0.8025155214235542,), μ = (z = 1.54693885061003, s = 0.0745790320610713))
 (σ = (x = -1.7908528242636836,), μ = (z = -0.9759534377545748, s = 0.09715515525703333))
 (σ = (x = -1.243438792364528,), μ = (z = 1.0680398024268236, s = 1.8415845124857169))
 (σ = (x = -1.41927338596261,), μ = (z = 0.9207954835068637, s = 0.8630895344548597))
 (σ = (x = -0.5654558925463217,), μ = (z = -0.4853975002235683, s = 0.26208180594904756))
 (σ = (x = -1.3038596674421752,), μ = (z = -1.929114873137208, s = 0.149916016902505))
 (σ = (x = -0.5286399824610131,), μ = (z = 0.1612197314920436, s = 0.17309553902367122))
 (σ = (x = -1.0658995614214306,), μ = (z = 0.23768244842150352, s = 0.4844050998447127))
 (σ = (x = -0.7933982423009571,), μ = (z = 0.11856812939246153, s = 1.2459628483567358))
 (σ = (x = -1.01058818044223,), μ = (z = 0.04558806030967599, s = 2.4583490260226406))
 (σ = (x = -0.4593376244722846,), μ = (z = -0.23592164129773735, s = 0.7260528535949017))
 (σ = (x = -1.09148191346783,), μ = (z = 0.25460471958841147, s = 1.3484374265272665))
 ⋮
 (σ = (x = -0.5801987204044089,), μ = (z = -0.9028133096618134, s = 0.15825090582374007))
 (σ = (x = -1.6416239788694245,), μ = (z = -0.5491305107889446, s = 1.0777608321670948))
 (σ = (x = -1.354378808723646,), μ = (z = 0.6399120867010685, s = 0.8270024785454115))
 (σ = (x = -1.6460183372586923,), μ = (z = -0.4738902763473705, s = 1.5589422057009745))
 (σ = (x = -1.0338260883824755,), μ = (z = -0.1532483837390508, s = 1.8386672636258248))
 (σ = (x = -0.9224015631855659,), μ = (z = -0.33994591755802017, s = 0.48104048852700126))
 (σ = (x = -0.7643085362668022,), μ = (z = 0.21162402886647913, s = 0.5163935571258168))
 (σ = (x = -0.8961478783321771,), μ = (z = -0.350820067866273, s = 1.5844804419621772))
 (σ = (x = -0.9205123331902257,), μ = (z = -0.22845740070594442, s = 1.4611594427003478))
 (σ = (x = -0.6392489317206853,), μ = (z = 0.31346142494080403, s = 0.19179304967981683))
 (σ = (x = -0.6860500949057662,), μ = (z = -0.0749823529477518, s = 0.021647319201071127))
 (σ = (x = -0.5929718337811852,), μ = (z = -1.1849421419789057, s = 0.03350051727314387))
 (σ = (x = -0.9074446905960427,), μ = (z = -0.7382744284589017, s = 0.033041157639288846))
 (σ = (x = -0.6319008473108123,), μ = (z = 0.665575175599197, s = 0.039572184074662044))
 (σ = (x = -0.7315791547633603,), μ = (z = 0.4845668407278135, s = 0.02823724087075678))
 (σ = (x = -1.0703816084178646,), μ = (z = -0.17517358290605448, s = 2.90800490155679))
 (σ = (x = -1.2325195390166057,), μ = (z = 1.3150084273685767, s = 0.9694984611678928))
 (σ = (x = -1.4350754133097698,), μ = (z = -0.16194385491160723, s = 0.21705244976726698))
 (σ = (x = -1.613687462971647,), μ = (z = -0.0550397509133167, s = 0.3997633159705947))
 (σ = (x = -1.4295068396473496,), μ = (z = 0.7035611290413929, s = 2.629901808498576))
 (σ = (x = -0.8998257325280834,), μ = (z = 0.4000272463858731, s = 0.4910378703606944))
 (σ = (x = -0.6375335991235168,), μ = (z = -0.34248704204602204, s = 1.6908282621357895))
 (σ = (x = -0.7542527056777214,), μ = (z = -0.11122049702504266, s = 0.40096078936903895))
 (σ = (x = -2.224882099296511,), μ = (z = 0.4395151749031513, s = 0.20015083955560758))
 (σ = (x = -1.4301415437133067,), μ = (z = -0.20865683782474426, s = 2.6701064572181776))
 (σ = (x = -2.008019402342103,), μ = (z = 0.3775811025661532, s = 0.478624380940115))
 (σ = (x = -1.518486566598051,), μ = (z = 0.6678005041626562, s = 0.5883842553605796))
 (σ = (x = -1.2603705249945103,), μ = (z = 0.49206280591373674, s = 0.5766827513214688))
 (σ = (x = -1.5136775656207444,), μ = (z = -0.04881621212207399, s = 0.9483718240967165))
 (σ = (x = -1.116558028117766,), μ = (z = -0.10760514267094609, s = 0.5759118851685311))
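Since each draw holds only the inner variables of the submodels, the values of μ and σ themselves have to be rebuilt from them. A minimal sketch in plain Julia, using the first three draws copied from the output above and the return expressions of μdist and σdist. The variable names are hypothetical.

```julia
# First three draws copied from the dynamicHMC output above.
draws = [
    (σ = (x = -2.7509614123288353,), μ = (z = 1.2507649275611596, s = 0.03573972089076121)),
    (σ = (x = -2.5687359737754343,), μ = (z = 0.9475268575919211, s = 0.04227077853218441)),
    (σ = (x = -1.3861789998350145,), μ = (z = -0.8038506706517361, s = 0.09821639861100565)),
]

# Inner variables are reached by ordinary nested property access.
z_draws = [d.μ.z for d in draws]

# Rebuild each submodel's returned value from its inner draws:
# μdist returns sqrt(s)*z and σdist returns abs(x).
μ_draws = [sqrt(d.μ.s) * d.μ.z for d in draws]
σ_draws = [abs(d.σ.x) for d in draws]
```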