martinjankowiak opened this issue 3 years ago
I agree we should improve the parametrization of AutoNormal. We did something similar to AutoLowRankMultivariateNormal in #2127, and indeed AutoLowRankMultivariateNormal now works much better than AutoMultivariateNormal.
Note I've been using a similar overparametrization in the AutoStructured guide, since the raw MVN via scale_tril performed so poorly:
https://github.com/pyro-ppl/pyro/blob/c8dc40a75cc4ff1f43c6ff9178d91c08155d7973/pyro/infer/autoguide/guides.py#L1564-L1573
Interesting. Note this isn't an overparametrization though: the diagonal of unit_scale_tril is all ones.
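One way to check that this is not an overparametrization: the map from (unit_scale_tril, scale_diag) to L is invertible, and the number of free parameters equals that of a general lower-triangular factor. The NumPy sketch below recovers both factors from a given L; the helper names `to_factors` and `num_free_params` are hypothetical and not part of Pyro or NumPyro.

```python
import numpy as np

def to_factors(L: np.ndarray):
    """Invert L = unit_scale_tril @ scale_diag for lower-triangular L with positive diagonal."""
    # The diagonal of L is exactly the diagonal scale.
    scale = np.diag(L).copy()
    # Dividing column j by scale[j] leaves ones on the diagonal (broadcast over columns).
    unit_scale_tril = L / scale
    return unit_scale_tril, scale

def num_free_params(n: int) -> int:
    # Strictly lower entries of unit_scale_tril plus n diagonal scales;
    # the fixed unit diagonal contributes no parameters.
    return n * (n - 1) // 2 + n

# Build a valid Cholesky factor from a random positive-definite matrix.
rng = np.random.default_rng(1)
n = 4
A = rng.normal(size=(n, n))
L = np.linalg.cholesky(A @ A.T + n * np.eye(n))
U, s = to_factors(L)
print(num_free_params(n))  # 10, the same as n * (n + 1) // 2 for a general lower-triangular L
```

Because the decomposition is unique, every lower-triangular L with a positive diagonal corresponds to exactly one (unit_scale_tril, scale_diag) pair.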
Add an alternative parameterization for lower-triangular Cholesky factors and consider using it in AutoMultivariateNormal:
rather than parameterizing a lower Cholesky factor as an unconstrained strictly lower-triangular piece plus a positive diagonal, we parameterize it as
L = unit_scale_tril @ scale_diag
where unit_scale_tril is lower triangular with ones along the diagonal and scale_diag is a positive diagonal matrix. For more details see the corresponding NumPyro PR.
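A minimal NumPy sketch of the proposed parameterization follows. It assumes the free parameters are an unconstrained (n, n) matrix, of which only the strictly lower triangle is used, and an unconstrained log-scale vector mapped through exp to get a positive diagonal. The function name `unit_scale_tril_param` and the exp constraint are illustrative assumptions, not Pyro's or NumPyro's actual implementation.

```python
import numpy as np

def unit_scale_tril_param(strict_lower: np.ndarray, log_scale: np.ndarray) -> np.ndarray:
    """Hypothetical helper: build L = unit_scale_tril @ scale_diag from unconstrained params.

    strict_lower: (n, n) unconstrained matrix; only its strictly lower part is used.
    log_scale: (n,) unconstrained vector; exp() makes the diagonal scales positive.
    """
    n = log_scale.shape[-1]
    # Unit lower-triangular factor: ones on the diagonal, free entries below it.
    unit_scale_tril = np.tril(strict_lower, k=-1) + np.eye(n)
    # Positive diagonal matrix (assumed exp constraint).
    scale_diag = np.diag(np.exp(log_scale))
    # Right-multiplying by a diagonal matrix scales column j by scale_diag[j, j].
    return unit_scale_tril @ scale_diag

rng = np.random.default_rng(0)
n = 4
strict_lower = rng.normal(size=(n, n))
log_scale = rng.normal(size=n)
L = unit_scale_tril_param(strict_lower, log_scale)
cov = L @ L.T  # covariance of the MVN whose Cholesky factor is L
```

Since column j of unit_scale_tril is multiplied by the j-th scale, the diagonal of L equals scale_diag exactly. Positivity is therefore enforced by a single elementwise exp, and the scales can be initialized independently of the unit-triangular part.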