probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Type signature of HeterogeneousMixture #500

Closed mirkoklukas closed 1 year ago

mirkoklukas commented 1 year ago

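Gen's HeterogeneousMixture expects a Vector{Distribution{T}} and has no method for a Vector{D} where D <: Distribution{T}. Julia's parametric types are invariant, so Vector{D} is not a subtype of Vector{Distribution{T}} even when D is a subtype of Distribution{T}. Note that passing [uniform, normal] is fine: the two elements have different concrete types, so the array literal's element type is widened to their common supertype and the result is a Vector{Distribution{T}}. However, [D(), D()] is a Vector{D}.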

This is an issue when you use, for instance, GenDistributions.jl's DistributionsBacked or my PushForward type.

The signature should be changed as follows:

# Current signature
HeterogeneousMixture(distributions::Vector{Distribution{T}}) where {T}
# Better signature, solving the issue.
HeterogeneousMixture(distributions::Vector{D}) where {T, D <: Distribution{T}}
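
To illustrate the difference between the two signatures without depending on Gen, here is a minimal sketch with toy types. `Dist`, `A`, `B`, `strict`, and `relaxed` are hypothetical stand-ins invented for this example, not part of Gen's API; `strict` mirrors the shape of the current signature and `relaxed` the shape of the proposed one.

```julia
# Hypothetical stand-ins for Distribution{T} and two concrete distributions.
abstract type Dist{T} end
struct A <: Dist{Float64} end
struct B <: Dist{Float64} end

# Same shape as the current signature: matches only Vector{Dist{T}} exactly.
strict(ds::Vector{Dist{T}}) where {T} = eltype(ds)

# Same shape as the proposed signature: matches any Vector{D} with D <: Dist{T}.
relaxed(ds::Vector{D}) where {T, D <: Dist{T}} = eltype(ds)

mixed = [A(), B()]   # element type is widened to Dist{Float64}
same  = [A(), A()]   # element type stays A

strict(mixed)                 # accepted
applicable(strict, same)      # false: Vector{A} is not a Vector{Dist{Float64}}
relaxed(mixed)                # accepted
relaxed(same)                 # accepted

# With the strict signature, annotating the literal's element type
# builds a Vector{Dist{Float64}} that is accepted.
strict(Dist{Float64}[A(), A()])
```

The last line shows a possible caller-side workaround under the strict signature, assuming all elements share the same `T`: an element-type-annotated array literal produces the abstractly typed vector directly.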

mirkoklukas commented 1 year ago

Example: the code below throws MethodError: no method matching HeterogeneousMixture(::Vector{DistributionsBacked}).

using .GenDistributions
using Distributions
using Gen

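# Two distributions from Distributions.jl, wrapped with DistributionsBacked.
const dirichlet = DistributionsBacked(alpha -> Dirichlet(alpha), (true,), true, Vector{Float64})
const flip      = DistributionsBacked(p -> Bernoulli(p), (true,), false, Bool)

# Throws the MethodError: the array literal is not a Vector{Distribution{T}}.
mix = HeterogeneousMixture([dirichlet, flip])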