hawkrobe closed this issue 7 years ago.
Thanks for the report and the reproducible example! I'm looking into it.
@fritzo : thanks! @jpchen has also been working with these models a fair amount and may have thoughts as well.
yep, working with @fritzo on this
After some deep diving with @jpchen, we found that your model and guide had different parameter shapes at the pyro.sample('intercept', Normal(...)) site. The fix is to change your model parameter shapes:
```diff
 subj_bias = pyro.sample('intercepts',
-                        Normal(b0.expand(num_subjects),
-                               sigma_subj.expand(num_subjects)))
+                        Normal(b0.expand(num_subjects, 1),
+                               sigma_subj.expand(num_subjects, 1)))
```
Sorry for such a difficult-to-diagnose error. We are adding an error message for this case so that debugging will be easier in the future (see #303). Let us know if this works for you.
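For illustration, the kind of silent shape mismatch at play here can be sketched with NumPy, whose broadcasting rules mirror PyTorch's; the variable names below are hypothetical and not taken from the original model:

```python
import numpy as np

num_subjects = 4

# Suppose the guide produces samples of shape (num_subjects, 1).
value = np.zeros((num_subjects, 1))

# Model parameters of shape (num_subjects,): broadcasting against the
# (num_subjects, 1) value silently yields a (num_subjects, num_subjects)
# result instead of raising an error.
bad = (value - np.zeros(num_subjects)) ** 2

# Model parameters of shape (num_subjects, 1): shapes line up as intended.
good = (value - np.zeros((num_subjects, 1))) ** 2

print(bad.shape)   # (4, 4)
print(good.shape)  # (4, 1)
```

Because broadcasting succeeds in both cases, no error is raised; the mismatched version just contributes a differently shaped (and differently sized) term to the loss, which is consistent with the loss converging to a different value rather than failing outright.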
@fritzo @jpchen : wow, that's super subtle (and a really surprising consequence -- I would've expected it to either throw an error or noticeably mess the whole thing up instead of just making the loss/uncertainty converge to different numbers!)
Thanks for taking the time to diagnose, and glad it's not a deeper issue!
Consider the vectorized and non-vectorized versions of the same hierarchical regression model. The main differences are that the vectorized version:

- uses `iarange` instead of `irange`
- observes from a `1 x batch_size` dimensional distribution instead of `batch_size` separate observes from 1-dimensional distributions
- uses `1 x k` dimensional params for a single `intercepts` sample site instead of `k` separate 1-dimensional params for `k` sample sites

Intuitively, I'd expect these to converge to the same loss; instead the vectorized version converges to ~1900 and the non-vectorized version converges to ~600. This same model written in webppl also converges to about ~600, so this might indicate an issue with scaling in the vectorized version?
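The intuition that the two versions should reach the same loss can be checked directly for the likelihood term: summing log-densities of a batched normal equals summing the log-densities of the corresponding independent 1-dimensional normals. A minimal sketch (the log-pdf helper and data values here are hypothetical, written in plain NumPy rather than Pyro):

```python
import numpy as np

def normal_logpdf(x, mu, sigma):
    # Log-density of a Normal(mu, sigma) at x, elementwise.
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

x = np.array([0.5, -1.2, 2.0])  # hypothetical observations

# Vectorized: one observe from a (1 x batch_size) distribution, then sum.
vectorized_total = normal_logpdf(x.reshape(1, -1), 0.0, 1.0).sum()

# Non-vectorized: batch_size separate observes from 1-d distributions.
separate_total = sum(normal_logpdf(xi, 0.0, 1.0) for xi in x)
```

The two totals agree, so a persistent gap in the converged losses points at how the terms are scaled or shaped inside the inference engine rather than at the likelihood arithmetic itself.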
Incidentally, the mean-field guide `sigma`s also converge to different values in the vectorized version, although the `mu` point estimates are the same. The unvectorized version matches webppl, but the vectorized version has much higher certainty. It's of course very plausible that there's a bug in my implementation of the vectorized model!