probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

`logpdf_grad` errors for `HomogeneousMixture` #445

Closed fzaiser closed 2 years ago

fzaiser commented 2 years ago

The following example crashes:

```julia
using Gen

@gen function test()
    mix = HomogeneousMixture(broadcasted_normal, [1, 0])
    means = hcat([0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    @trace(mix([0.25, 0.25, 0.25, 0.25], means, [0.1, 0.1, 0.1, 0.1]), :x)
end
trace = Gen.simulate(test, ())
result = Gen.hmc(trace, selectall())
```

It throws the following error:

```
ERROR: LoadError: DimensionMismatch("new dimensions (1, 2) must be consistent with array size 4")
 [1] (::Base.var"#throw_dmrsa#196")(::Tuple{Int64,Int64}, ::Int64) at ./reshapedarray.jl:41
 [2] reshape at ./reshapedarray.jl:45 [inlined]
 [3] reshape(::Array{Float64,1}, ::Int64, ::Int64) at ./reshapedarray.jl:116
 [4] logpdf_grad(::HomogeneousMixture{Array{Float64,N} where N}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,2}, ::Array{Float64,1}) at [...]/packages/Gen/[...]/src/modeling_library/mixture.jl:115
...
```

I believe the cause is this line: https://github.com/probcomp/Gen.jl/blob/fa759d399d97c745b15416db6bcb5d8701bd29e6/src/modeling_library/mixture.jl#L117. There, `length(dist.dims)` should be replaced by `K` (the number of mixture components). That change removes the exception, but I don't understand the code well enough to be sure this is the correct fix, or whether other parts of the code also need to be fixed.
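To make the mismatch concrete: in the example above there are `K = 4` components but `dims = [1, 0]` has length 2, so a length-`K` vector gets reshaped to a `(1, 2)` target. A standalone illustration (plain Julia, not Gen's actual code):

```julia
K = 4          # number of mixture components in the example
dims = [1, 0]  # per-argument dimensionalities passed to HomogeneousMixture
v = zeros(K)   # a flat vector with one entry per component

# reshape(v, 1, length(dims)) is reshape(v, 1, 2): 2 slots for 4
# elements, which throws the DimensionMismatch seen in the trace.

reshape(v, 1, K)  # consistent: 4 elements into a 1×4 array
```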

bzinberg commented 2 years ago

Haven't looked in depth, but I suspect this is indeed due to an assumption somewhere that args to distributions will be flat, i.e. cannot be array-valued. The use of length instead of size/axes looks suspect to me.

In general, the args to a distribution could be arrays of different shapes. I'm not aware of us having general machinery for flattening and unflattening arrays in the gradient operations (nor am I sure that flattening and unflattening is the right thing to do, necessarily).
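A hypothetical sketch of what such flattening/unflattening machinery could look like for heterogeneously shaped argument arrays (illustration only; none of this exists in Gen, and the names `flatten`/`unflatten` are made up here):

```julia
# Flatten a tuple of arrays of arbitrary shapes into one vector,
# remembering each original shape so the operation can be inverted.
flatten(args) = (vcat(map(vec, collect(args))...), map(size, collect(args)))

# Invert: split the flat vector back into arrays of the recorded shapes.
function unflatten(flat::Vector, shapes)
    out = Any[]
    i = 1
    for s in shapes
        n = prod(s)
        push!(out, reshape(flat[i:i+n-1], s...))
        i += n
    end
    return Tuple(out)
end
```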

bzinberg commented 2 years ago

(Oops, misread a doc. Deleted comment.)

fzaiser commented 2 years ago

@bzinberg Thanks for the quick reply! In the documentation for HomogeneousMixture, there is an example with a multivariate normal distribution, which takes a mean vector and a covariance matrix (i.e. different shapes for the two arguments). Therefore, I thought it was supported. Do you think this functionality would be difficult to implement?
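For reference, here is roughly what the documented multivariate-normal example looks like (a sketch based on the `HomogeneousMixture` docs; I'm assuming the convention that each component argument gains one trailing dimension indexed by component):

```julia
using Gen

# mvnormal takes a mean vector (1-dimensional) and a covariance matrix
# (2-dimensional), so the dims argument is [1, 2].
mixture = HomogeneousMixture(mvnormal, [1, 2])

@gen function model()
    weights = [0.5, 0.5]
    # Component arguments stacked along a trailing component axis:
    mus = hcat([0.0, 0.0], [5.0, 5.0])      # 2×2: mus[:, k] is component k's mean
    covs = cat([1.0 0.0; 0.0 1.0],
               [1.0 0.0; 0.0 1.0]; dims=3)  # 2×2×2: covs[:, :, k] is its covariance
    @trace(mixture(weights, mus, covs), :y)
end
```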

alex-lew commented 2 years ago

Hi @fzaiser! I think a lot of us were on winter break when you posted this and it fell through the cracks -- sorry about that!

I think you're right that the `length(dist.dims)` on that line should be replaced by `K`, the number of components. Thanks for tracking this down and filing the bug!

(As an aside, HMC will struggle to explore multiple modes in this target — but I think that may be the point of the experiment :).)

As Ben mentioned, there are parts of Gen (including the @dist DSL) that make certain restrictive assumptions about data shapes, but I don't think you should run into that on this example.

fzaiser commented 2 years ago

Hi @alex-lew, no problem and thanks for the fix! I hope to have some time to experiment with it soon. Indeed, I'm aware of HMC struggling with such a multi-modal distribution. :) I was just playing around with gradient-based inference methods when I hit the bug and HMC was the simplest way to reproduce it.