Closed: amifalk closed this issue 3 months ago.
Thanks @amifalk! This is a bug: we allow `plate` to be applied to the unconstrained value here: https://github.com/pyro-ppl/numpyro/blob/b16741cc163b1a3753a331e3200c64cced9eb804/numpyro/infer/reparam.py#L283-L286
A temporary fix is to remove the plate around the first site and declare its `k` values as a single event instead:

```python
P_cov = numpyro.sample('P_cov', dist.InverseGamma(3, 1).expand([k]).to_event())
```
Minimal example:
I'm not entirely sure what's going on here. The following model works with vanilla NUTS, but raises

```
TypeError: mul got incompatible shapes for broadcasting: (3, 5), (5, 5)
```

when I try to run NUTS after reparameterizing with `NeuTraReparam`. If I remove the top two plates and replace the latents with constants, the code runs, but I get the following warnings:
Maybe it has something to do with having multiple plate names with the same dimension?
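For reference, the error message follows from the standard broadcasting rule that NumPy and JAX share: shapes are aligned from the right, and each pair of sizes must be equal or include a 1. For `(3, 5)` and `(5, 5)` the trailing sizes match (5 vs 5), but the leading pair is 3 vs 5, so the multiply fails. The pure-Python sketch below restates that rule; `broadcast_shapes` here is a hypothetical helper (real code would use `numpy.broadcast_shapes`), and it raises `TypeError` only to mirror the message above.

```python
def broadcast_shapes(*shapes):
    """Return the broadcast shape of `shapes` under NumPy/JAX rules, or raise TypeError."""
    ndim = max(len(s) for s in shapes)
    # Left-pad every shape with 1s so all shapes have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    # Compare sizes position by position; sizes of 1 stretch to match.
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise TypeError(f"incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


print(broadcast_shapes((3, 5), (5,)))       # (3, 5)
print(broadcast_shapes((3, 1, 5), (5, 5)))  # (3, 5, 5)
try:
    broadcast_shapes((3, 5), (5, 5))
except TypeError as err:
    print(err)
```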