probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Add truncated normal distribution #353

Open - marcoct opened this issue 3 years ago

femtomc commented 3 years ago

I'll take a swing at this today. Looks cool.

ztangent commented 3 years ago

@femtomc I wonder if it might be worth making a generalized truncation operation, similar to Distributions.jl's truncated: https://juliastats.org/Distributions.jl/stable/truncate/

ztangent commented 3 years ago

We could do a fallback implementation like:

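function _logpdf(d::Truncated, x::T) where {T<:Real}
    if d.lower <= x <= d.upper
        # inside the support: renormalize the untruncated log density by
        # d.logtp, the log probability mass of [lower, upper]
        logpdf(d.untruncated, x) - d.logtp
    else
        # outside the support the density is zero, so the log density is -Inf
        TF = float(T)
        -TF(Inf)
    end
end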

and specialized implementations for common truncated distributions like the truncated normal.
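For the specialized route, a truncated normal log-density might look roughly like the plain-Julia sketch below. The TruncNormal type, its fields (mu, std, lower, upper) and the helper functions are hypothetical, not part of Gen.jl. To stay dependency-free it approximates erfc with the Numerical Recipes erfcc formula (fractional error below about 1.2e-7); a real implementation would more likely call SpecialFunctions.erfc or reuse Distributions.jl.

```julia
# Approximate erfc with the Chebyshev-fitted `erfcc` formula from Numerical
# Recipes (fractional error below about 1.2e-7). Used only to keep this sketch
# free of package dependencies; SpecialFunctions.erfc is the usual choice.
function erfc_approx(x::Float64)
    z = abs(x)
    t = 1.0 / (1.0 + 0.5 * z)
    poly = -z * z - 1.26551223 + t * (1.00002368 + t * (0.37409196 +
           t * (0.09678418 + t * (-0.18628806 + t * (0.27886807 +
           t * (-1.13520398 + t * (1.48851587 + t * (-0.82215223 +
           t * 0.17087277))))))))
    r = t * exp(poly)
    return x >= 0 ? r : 2.0 - r
end

# Standard normal CDF: Phi(z) = erfc(-z / sqrt(2)) / 2
std_normal_cdf(z::Float64) = 0.5 * erfc_approx(-z / sqrt(2.0))

# Hypothetical truncated normal type (not an existing Gen.jl type)
struct TruncNormal
    mu::Float64
    std::Float64
    lower::Float64
    upper::Float64
end

function truncnormal_logpdf(d::TruncNormal, x::Real)
    # zero density (log density -Inf) outside [lower, upper]
    (d.lower <= x <= d.upper) || return -Inf
    z = (x - d.mu) / d.std
    # log density of the untruncated normal
    lp = -0.5 * z^2 - log(d.std) - 0.5 * log(2π)
    # log probability mass of [lower, upper] under the untruncated normal
    logtp = log(std_normal_cdf((d.upper - d.mu) / d.std) -
                std_normal_cdf((d.lower - d.mu) / d.std))
    return lp - logtp
end

# Usage
tn = TruncNormal(0.0, 1.0, -1.0, 2.0)
truncnormal_logpdf(tn, 0.5)   # finite
truncnormal_logpdf(tn, 3.0)   # -Inf: outside [lower, upper]
```

Subtracting two CDF values loses precision when [lower, upper] sits far out in a tail, so a production version would want a more careful log-difference there.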

femtomc commented 3 years ago

I was discussing this a bit with Alex -- will comment in a bit.

femtomc commented 3 years ago

I agree with your comment. Basically, start with the truncated normal. As we build more, we can specialize and add a function truncate that mimics the Distributions.jl method truncated.
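A generic truncate could be sketched roughly as follows, in plain Julia and under the assumption that the caller supplies both the untruncated log-density and its CDF. TruncatedDist, truncate_dist and trunc_logpdf are hypothetical names, not Gen.jl or Distributions.jl API; the log-density has the same shape as the fallback _logpdf earlier in the thread.

```julia
# Hypothetical generic truncation wrapper (plain Julia, not Gen.jl's API).
# It needs BOTH the log-density and the CDF of the untruncated distribution.
struct TruncatedDist{F,G}
    logpdf::F       # x -> log density of the untruncated distribution
    cdf::G          # x -> CDF of the untruncated distribution
    lower::Float64
    upper::Float64
    logtp::Float64  # log probability mass of [lower, upper], computed from the CDF
end

# Hypothetical constructor, loosely modeled on Distributions.jl's `truncated`
# (named `truncate_dist` to avoid clashing with `Base.truncate`).
function truncate_dist(logpdf, cdf, lower::Real, upper::Real)
    lower < upper || throw(ArgumentError("need lower < upper"))
    logtp = log(cdf(upper) - cdf(lower))
    return TruncatedDist(logpdf, cdf, Float64(lower), Float64(upper), logtp)
end

# Renormalize inside the support, -Inf outside it.
function trunc_logpdf(d::TruncatedDist, x::Real)
    return d.lower <= x <= d.upper ? d.logpdf(x) - d.logtp : -Inf
end

# Usage: an exponential distribution with rate 2, truncated to [0.5, 1.5].
rate = 2.0
d_exp = truncate_dist(x -> x < 0 ? -Inf : log(rate) - rate * x,
                      x -> x < 0 ? 0.0 : 1 - exp(-rate * x),
                      0.5, 1.5)
trunc_logpdf(d_exp, 1.0)
```

The CDF is what makes logtp cheap to compute at construction time; Gen distributions don't expose one, which is the obstacle ztangent raises in the next comment.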

ztangent commented 3 years ago

I was actually thinking the other way round! But on further thought I realized a fallback implementation won't be that easy: we don't equip Gen distributions with CDFs the way Distributions.jl does, so we can't easily calculate d.logtp for every distribution.
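Written out for a continuous distribution with density p and CDF F, truncation to [a, b] divides by the probability mass of the interval, and that normalizer is exactly the d.logtp term in the fallback _logpdf. Without F it would have to be obtained some other way, for example by numerically integrating p over [a, b]:

```latex
\log p_{[a,b]}(x) =
\begin{cases}
\log p(x) - \log\bigl(F(b) - F(a)\bigr) & a \le x \le b,\\
-\infty & \text{otherwise},
\end{cases}
\qquad \texttt{logtp} = \log\bigl(F(b) - F(a)\bigr).
```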

marcoct commented 3 years ago

See https://github.com/probcomp/Gen.jl/issues/362