cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Help fitting a simple t distribution #304

Closed cwoode closed 2 years ago

cwoode commented 2 years ago
using MeasureTheory
using Soss
using SampleChainsDynamicHMC
using Random
import Distributions as Dist

model_test_t = @model T begin
    ν ~ Exponential(10.0)
    y ~ StudentT(ν) |> iid(T)
end;

y_test = rand(Dist.TDist(2.5),1000);

post_test = sample(model_test_t(T=length(y_test))|(y=y_test,), dynamichmc())

This gives the following:

4000-element MultiChain with 4 chains and schema (ν = Float64,)
(ν = 155.0±51.0,)

The posterior for ν is nowhere near the true value of 2.5, and I can't seem to get this to work. Also, are there any hints on getting AdvancedHMC to work with the current release? dynamichmc tends to abort on any sampling error.

cscherrer commented 2 years ago

Thanks for letting me know about this. The problem is that Soss currently uses logdensity, which is defined relative to a base measure. For StudentT the two are defined as


# Log-density relative to the base measure: only the x-dependent kernel.
function logdensity(d::StudentT{(:ν,)}, x)
    ν = d.ν
    return xlog1py((ν + 1) / (-2), x^2 / ν)
end

function basemeasure(d::StudentT{(:ν,)})
    inbounds(x) = true
    constℓ = 0.0
    # ν-dependent log normalizing constant; needed whenever ν is not fixed.
    varℓ() = loggamma((d.ν+1)/2) - loggamma(d.ν/2) - log(π * d.ν) / 2
    base = Lebesgue(ℝ)
    FactoredBase(inbounds, constℓ, varℓ, base)
end

If ν is constant, there's no need to compute the normalizing constant loggamma((d.ν+1)/2) - loggamma(d.ν/2) - log(π * d.ν) / 2. But in this case we need it, or the result is wrong.
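For reference, the split between the two functions above lines up with the standard Student's t log-density. This is just the textbook formula written out, with the two groups labelled by where they appear in the snippet; the labels are mine, not MeasureTheory's:

```latex
\log p(x \mid \nu)
  = \underbrace{\log\Gamma\!\left(\tfrac{\nu+1}{2}\right)
      - \log\Gamma\!\left(\tfrac{\nu}{2}\right)
      - \tfrac{1}{2}\log(\pi\nu)}_{\text{normalizer (base measure)}}
  \;\underbrace{-\,\tfrac{\nu+1}{2}\,
      \log\!\left(1 + \tfrac{x^{2}}{\nu}\right)}_{\text{kernel (logdensity)}}
```

With T observations the normalizer enters the log-posterior of ν a total of T times, so leaving it out changes the shape of the posterior rather than shifting it by a constant.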

I need to make Soss more intelligent about tracking this sort of thing, but for now "correct" is more important than "fast". So for a quick fix, I'll tag a new release that uses logpdf instead of logdensity.
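To make the difference between the kernel and the full log-density concrete, here is a minimal sketch using only Base Julia. It is not how Soss or MeasureTheory compute anything: `loggamma_half`, `kernel`, `lognorm` and `full` are hypothetical helper names, and the sketch is restricted to integer ν so that the log-gamma terms can be obtained from the recurrence Γ(z + 1) = z Γ(z) without an extra package.

```julia
# Hypothetical helper: log Γ(n/2) for integer n ≥ 1, built from
# Γ(1/2) = √π, Γ(1) = 1 and the recurrence Γ(z + 1) = z Γ(z).
function loggamma_half(n::Integer)
    n ≥ 1 || throw(ArgumentError("n must be ≥ 1"))
    z = isodd(n) ? 0.5 : 1.0
    acc = isodd(n) ? log(π) / 2 : 0.0
    while z < n / 2
        acc += log(z)
        z += 1.0
    end
    return acc
end

# x-dependent kernel only (the part returned by `logdensity` in the snippet)
kernel(ν, x) = -(ν + 1) / 2 * log1p(x^2 / ν)

# ν-dependent log normalizer (the `varℓ` term), here for integer ν only
lognorm(ν::Integer) = loggamma_half(ν + 1) - loggamma_half(ν) - log(π * ν) / 2

# Full log-density: normalizer plus kernel
full(ν::Integer, x) = lognorm(ν) + kernel(ν, x)

for ν in (1, 2, 5, 10)
    println("ν = ", ν, "  lognorm = ", lognorm(ν), "  full(ν, 0) = ", full(ν, 0.0))
end
```

At x = 0 the kernel is zero for every ν, so all of the ν-dependence there sits in the normalizer. That is the term a sampler never sees if it only evaluates the kernel.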

Also, I just noticed that I'm using a different default parameterization than Distributions. I'll keep Exponential{(:λ,)} as the "rate" parameterization, but I need to add a "scale" parameterization too, and... maybe that should be the default? I mostly try to match Distributions.jl for defaults where it makes sense.
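As a quick illustration of why the parameterization matters, the sketch below writes out the exponential log-density by hand in both conventions. The function names `logpdf_rate` and `logpdf_scale` are made up for this example and are not part of either package.

```julia
# Rate parameterization: density λ * exp(-λ x), mean 1/λ
logpdf_rate(λ, x) = log(λ) - λ * x

# Scale parameterization: density exp(-x / θ) / θ, mean θ
logpdf_scale(θ, x) = -log(θ) - x / θ

x = 3.0
println(logpdf_rate(10.0, x))   # 10.0 read as a rate: prior mean 0.1
println(logpdf_scale(10.0, x))  # 10.0 read as a scale: prior mean 10
println(logpdf_rate(0.1, x))    # rate 0.1 agrees with scale 10
```

So a prior written as `Exponential(10.0)` with the scale convention in mind becomes a much tighter prior around small ν when the same number is read as a rate.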

For now, if I change your ν ~ Exponential(10.0) to ν ~ Exponential(λ=0.1) (making the rate parameter explicit) and make Soss use logpdf, I can do

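julia> function f(ν,T=1000)
           # `m` is presumably the model above with ν ~ Exponential(λ=0.1),
           # and `Dists` an alias for Distributions
           y_test = rand(Dists.TDist(ν),T)
           post_test = sample(m(T=T) | (y=y_test,), dynamichmc())
       end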
f (generic function with 2 methods)

julia> using TupleVectors: summarize

julia> for ν ∈ (0.1,0.2,0.5,1.0,2.0,5.0,10.0)
           println(ν," => ", summarize(f(ν)))
       end
0.1 => (ν = 0.05077±0.0024,)
0.2 => (ν = 0.1084±0.0054,)
0.5 => (ν = 0.2514±0.014,)
1.0 => (ν = 0.496±0.032,)
2.0 => (ν = 0.951±0.08,)
5.0 => (ν = 1.91±0.22,)
10.0 => (ν = 3.42±0.61,)

Those numbers look suspicious to me, so I'm not sure yet that the fix is complete. I'll need to do some more checking.

I'd love to get AdvancedHMC working, and it's one of those things that's "not hard in principle", and would probably just take a few days. It's just that there are so many of those :)

cwoode commented 2 years ago

Thanks for the quick response. Yes, it doesn't look right. This is an incredible project btw.

cscherrer commented 2 years ago

Ok, think I got it:

julia> for ν ∈ (0.1,0.2,0.5,1.0,2.0,5.0,10.0)
           println(ν," => ", summarize(f(ν)))
       end
0.1 => (ν = 0.1015±0.0033,)
0.2 => (ν = 0.1965±0.0069,)
0.5 => (ν = 0.4939±0.019,)
1.0 => (ν = 1.008±0.047,)
2.0 => (ν = 1.99±0.13,)
5.0 => (ν = 5.06±0.6,)
10.0 => (ν = 11.9±2.8,)

and thanks!

cwoode commented 2 years ago

This is working now. Thanks

4000-element MultiChain with 4 chains and schema (ν = Float64,)
(ν = 2.183±0.15,)