pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[FR] add an alternative parameterization for lower triangular cholesky factors #2924

Open martinjankowiak opened 3 years ago

martinjankowiak commented 3 years ago

add an alternative parameterization for lower triangular cholesky factors and consider using this in AutoMultivariateNormal:

instead of parameterizing a lower cholesky factor as an unconstrained strictly lower triangular piece and a positive diagonal, we parameterize it as

L = unit_scale_tril @ scale_diag

where unit_scale_tril is lower triangular with ones along the diagonal and scale_diag is a positive diagonal matrix.
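As a rough illustration of the proposal, here is a minimal PyTorch sketch that builds L from two unconstrained tensors. The function name `unit_diag_scale_tril` and the arguments `offdiag` and `log_scale` are hypothetical; this is not the Pyro or NumPyro implementation. It also assumes the positive diagonal is obtained by exponentiating an unconstrained vector, which is one common choice among several.

```python
import torch


def unit_diag_scale_tril(offdiag, log_scale):
    """Sketch of L = unit_scale_tril @ diag(scale_diag).

    offdiag:   shape (n * (n - 1) // 2,), unconstrained strictly lower entries
    log_scale: shape (n,), unconstrained; exp() gives the positive scale_diag
    (both names are illustrative assumptions, not library API)
    """
    n = log_scale.shape[-1]
    # Start from the identity so the diagonal of unit_scale_tril is all ones.
    unit_scale_tril = torch.eye(n, dtype=log_scale.dtype)
    # Fill the strictly lower triangle (offset=-1 excludes the diagonal).
    rows, cols = torch.tril_indices(n, n, offset=-1)
    unit_scale_tril[rows, cols] = offdiag
    scale_diag = log_scale.exp()
    # Right-multiplying by a diagonal matrix scales column j by scale_diag[j],
    # so broadcasting over the last dimension equals unit_scale_tril @ diag(scale_diag).
    return unit_scale_tril * scale_diag


n = 3
offdiag = torch.randn(n * (n - 1) // 2, requires_grad=True)
log_scale = torch.zeros(n, requires_grad=True)
L = unit_diag_scale_tril(offdiag, log_scale)
# L is lower triangular with a positive diagonal, so it is a valid scale_tril.
mvn = torch.distributions.MultivariateNormal(torch.zeros(n), scale_tril=L)
```

In this sketch the diagonal of L equals scale_diag, while each strictly lower entry of L is the matching entry of unit_scale_tril multiplied by the scale of its column.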

for more details see the corresponding NumPyro PR

fritzo commented 3 years ago

I agree we should improve the parametrization of AutoNormal. We did something similar to AutoLowRankMultivariateNormal in #2127, and indeed AutoLowRankMultivariateNormal now works much better than AutoMultivariateNormal.

fritzo commented 3 years ago

Note I've been using a similar overparametrization in the AutoStructured guide, since the raw MVN via scale_tril performed so poorly: https://github.com/pyro-ppl/pyro/blob/c8dc40a75cc4ff1f43c6ff9178d91c08155d7973/pyro/infer/autoguide/guides.py#L1564-L1573

martinjankowiak commented 3 years ago

interesting. note this isn't an overparametrization though: the diagonal on unit_scale_tril is all ones
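The parameter count behind this remark can be written out explicitly. In the sketch below, U and D are shorthand introduced here for unit_scale_tril and scale_diag, and n is the dimension; none of this notation appears in the thread itself.

```latex
% U: unit lower triangular (ones on the diagonal), D = diag(d_1, ..., d_n), d_i > 0
L = U D
\quad\Longrightarrow\quad
L_{ii} = d_i, \qquad L_{ij} = U_{ij}\, d_j \quad (i > j)

% free parameters: strictly lower part of U, plus the diagonal of D
\underbrace{\tfrac{n(n-1)}{2}}_{U} + \underbrace{n}_{D} = \tfrac{n(n+1)}{2}

% the map is invertible, so no degree of freedom is redundant
d_j = L_{jj}, \qquad U_{ij} = L_{ij} / L_{jj}

% implied covariance
\Sigma = L L^{\top} = U D^{2} U^{\top}
```

The total n(n+1)/2 matches the number of free entries in an ordinary lower triangular factor with a positive diagonal, so the two parameterizations describe the same set of factors and differ only in the coordinates being optimized.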