Status: closed (fzaiser closed this issue 2 years ago).
I haven't looked in depth, but I suspect this is indeed due to an assumption somewhere that the arguments to distributions are flat, i.e. not array-valued. The use of `length` instead of `size`/`axes` looks suspect to me.
In general, the arguments to a distribution can be arrays of different shapes. I'm not aware of us having any general machinery for flattening and unflattening arrays in the gradient operations (nor am I sure that flattening and unflattening is necessarily the right approach).
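For illustration, here is a minimal sketch of what such flattening/unflattening machinery could look like. It is written in Python/NumPy rather than Julia, and `flatten_args`/`unflatten_args` are hypothetical names, not part of Gen. The key point is that the full shape of each argument has to be recorded. A single length or element count per argument is not enough to rebuild it: 9 elements could be a 9-vector or a 3×3 matrix.

```python
import numpy as np

def flatten_args(args):
    """Hypothetical helper: concatenate arguments of arbitrary shapes into one
    flat vector, recording each original shape so it can be undone later."""
    shapes = [np.shape(a) for a in args]
    flat = np.concatenate([np.ravel(a) for a in args])
    return flat, shapes

def unflatten_args(flat, shapes):
    """Hypothetical inverse of flatten_args: split the flat vector by each
    argument's element count (product of its shape) and restore the shape."""
    sizes = [int(np.prod(s)) for s in shapes]
    pieces = np.split(flat, np.cumsum(sizes)[:-1])
    return [p.reshape(s) for p, s in zip(pieces, shapes)]

# A mean vector and a covariance matrix: two arguments of different shapes.
mean = np.zeros(3)
cov = np.eye(3)
flat, shapes = flatten_args([mean, cov])
print(flat.shape)  # (12,)
print(shapes)      # [(3,), (3, 3)]
mean2, cov2 = unflatten_args(flat, shapes)
```

Whether a gradient API should expose such a flat vector at all, or instead return gradients in the arguments' own shapes, is exactly the open design question raised above.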
(Oops, misread a doc. Deleted comment.)
@bzinberg Thanks for the quick reply! The documentation for `HomogeneousMixture` includes an example with a multivariate normal distribution, which takes a mean vector and a covariance matrix (i.e., two arguments of different shapes), so I assumed this was supported. Do you think this functionality would be difficult to implement?
Hi @fzaiser! I think a lot of us were on winter break when you posted this, and it fell through the cracks. Sorry about that!
I think you're right that the `length(dist.dims)` on that line should be replaced by `K`, the number of components. Thanks for tracking down and filing the bug!
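To see why the gradient with respect to the mixture weights needs `K` entries rather than `length(dist.dims)` entries, here is a minimal NumPy sketch. It is not Gen's implementation. It computes the log-density of a homogeneous mixture of multivariate normals together with its gradient with respect to the weights. The sketch assumes a `dims`-style list with one entry per base-distribution argument (here `[1, 2]` for a mean vector and a covariance matrix), so its length is 2 no matter how many components there are. It also assumes components are stacked along the last axis of each argument. Both conventions and all function names are assumptions for illustration.

```python
import numpy as np

def mvnormal_logpdf(x, mean, cov):
    """Log-density of a multivariate normal N(mean, cov) at x."""
    d = x.shape[0]
    diff = x - mean
    _, logdet = np.linalg.slogdet(cov)
    maha = diff @ np.linalg.solve(cov, diff)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)

def mixture_logpdf_and_weight_grad(x, weights, means, covs):
    """Hypothetical helper: log of sum_k weights[k] * N(x; means[:, k], covs[:, :, k])
    and its gradient w.r.t. the (unconstrained) weights. Components are
    stacked along the last axis of each argument (assumption of this sketch)."""
    K = weights.shape[0]
    comp_logpdfs = np.array(
        [mvnormal_logpdf(x, means[:, k], covs[:, :, k]) for k in range(K)]
    )
    # Stable log(sum_k w_k p_k(x)) via log-sum-exp.
    log_total = np.logaddexp.reduce(np.log(weights) + comp_logpdfs)
    # d/dw_k log(sum_j w_j p_j(x)) = p_k(x) / sum_j w_j p_j(x): one entry per component.
    weight_grad = np.exp(comp_logpdfs - log_total)
    return log_total, weight_grad

dims = [1, 2]  # one entry per argument: mean (vector), covariance (matrix)
K, d = 3, 2
weights = np.full(K, 1.0 / K)
means = np.zeros((d, K))
covs = np.stack([np.eye(d)] * K, axis=-1)  # shape (d, d, K)
x = np.array([0.5, -0.5])

logp, grad = mixture_logpdf_and_weight_grad(x, weights, means, covs)
print(len(dims), grad.shape)  # 2 (3,)
```

Sizing the weight gradient by the number of arguments (2) instead of the number of components (3) would produce an array of the wrong length, which is consistent with the crash reported in this issue.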
(As an aside, HMC will struggle to explore multiple modes in this target, but I think that may be the point of the experiment :).)
As Ben mentioned, some parts of Gen (including the `@dist` DSL) make restrictive assumptions about data shapes, but I don't think you should run into them in this example.
Hi @alex-lew, no problem, and thanks for the fix! I hope to have some time to experiment with it soon. Indeed, I'm aware that HMC struggles with such a multimodal distribution :) I was just playing around with gradient-based inference methods when I hit the bug, and HMC was the simplest way to reproduce it.
The following example crashes:
It throws the following error:
I believe the cause is in the following line, where `length(dist.dims)` should be replaced by `K`: https://github.com/probcomp/Gen.jl/blob/fa759d399d97c745b15416db6bcb5d8701bd29e6/src/modeling_library/mixture.jl#L117
This change removes the exception, but I don't understand the code well enough to be sure that it is the correct fix or whether other parts of the code need fixing too.