Closed fritzo closed 3 years ago
@eb8680 could we pair code on Gaussian patterns this week? I think this PR now has correct linear algebra, but it produces slightly different patterns. Specifically, the new square root representation can no longer be negated, so we'll need to keep `Unary(ops.neg, Gaussian)` lazy.
This is currently failing a Pyro AutoGaussian pyro-cov test, where backward sampling on the optimized Gaussian erroneously tries to call `._sample({"init_loc"}, {})` on this Gaussian:
```python
Gaussian(
    torch.tensor(...5..., dtype=torch.float64),
    torch.tensor(...6 x 5..., dtype=torch.float64),
    (('init_loc', Reals[5]),
     ('init_loc_scale__BOUND_36', Real),))
```
EDIT: this was resolved by ensuring that marginalization preserves non-reduced inputs.
@eb8680 could you please review this PR in general? @fehiepsi could you please review the linear algebra?
I'm happy to walk you through the changes over zoom.
Whoa, impressive work to make this possible! I haven't looked into the code yet, but my general concerns would be the marginalization and rank-compression logic. I will look into the details later this week.
Thanks for reviewing, @fehiepsi! Here are a few weak reasons I chose `white_vec` instead of `info_vec`:

- `white_vec` seems natural.
- `white_vec` is space optimal in the low-rank case, with `white_vec.shape[-1] == rank` versus `info_vec.shape[-1] == dim`. E.g. in the rank-1 case we get up to a factor of two savings: O(rank(1+dim)) < O(dim(1+rank)).
- Because `white_vec` does not depend on the real inputs, it need not be rearranged when interleaving or substituting real variables. This is minor, but it does simplify code a bit.

Thanks for reviewing, @eb8680 and @fehiepsi! I believe I've addressed all comments.
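To make the space-optimality point concrete, here is a quick arithmetic sketch (plain Python with illustrative shape accounting; not funsor code):

```python
# Parameter counts for a single Gaussian over `dim` real dimensions at
# rank `rank` (illustrative accounting only, not the actual funsor layout).

def sqrt_form_params(dim: int, rank: int) -> int:
    # white_vec: shape (rank,); precision square root: shape (dim, rank)
    return rank + dim * rank  # O(rank * (1 + dim))

def info_form_params(dim: int, rank: int) -> int:
    # info_vec: shape (dim,); low-rank precision factor: shape (dim, rank)
    return dim + dim * rank  # O(dim * (1 + rank))

# In the rank-1 case the savings approach a factor of two:
print(sqrt_form_params(2369, 1))  # 2370
print(info_form_params(2369, 1))  # 4738
```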
I plan to add more patterns in subsequent PRs that handle Gaussian variable elimination, e.g. in #553 and https://github.com/pyro-ppl/funsor/tree/tractable-for-gaussians
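For readers following along, here is a minimal numeric sketch of the square-root representation discussed in this thread. The parameter name `prec_sqrt` and the exact log-density form `-1/2 * ||x @ prec_sqrt - white_vec||**2` are assumptions for illustration, not necessarily the PR's exact code:

```python
import numpy as np

def sqrt_form_logp(x, prec_sqrt, white_vec):
    # Assumed unnormalized log-density of the square-root parametrization.
    resid = x @ prec_sqrt - white_vec
    return -0.5 * (resid * resid).sum(-1)

def info_form_logp(x, precision, info_vec):
    # Classic information-form log-density, up to a normalizing constant.
    return -0.5 * x @ precision @ x + x @ info_vec

rng = np.random.default_rng(0)
dim, rank = 3, 3
prec_sqrt = rng.standard_normal((dim, rank))
white_vec = rng.standard_normal(rank)
x = rng.standard_normal(dim)

# The two forms agree up to the constant -1/2 * ||white_vec||^2:
precision = prec_sqrt @ prec_sqrt.T
info_vec = prec_sqrt @ white_vec
lhs = sqrt_form_logp(x, prec_sqrt, white_vec)
rhs = info_form_logp(x, precision, info_vec) - 0.5 * white_vec @ white_vec
assert np.isclose(lhs, rhs)

# "Addition amounts to concatenation": summing two such log-densities over
# the same variable just stacks their parameters along the rank dimension.
prec_sqrt2 = rng.standard_normal((dim, 2))
white_vec2 = rng.standard_normal(2)
assert np.isclose(
    sqrt_form_logp(x, np.concatenate([prec_sqrt, prec_sqrt2], -1),
                   np.concatenate([white_vec, white_vec2], -1)),
    sqrt_form_logp(x, prec_sqrt, white_vec)
    + sqrt_form_logp(x, prec_sqrt2, white_vec2),
)
```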
Resolves #567. Adapts @fehiepsi's https://github.com/pyro-ppl/pyro/pull/2019

This switches the internal Gaussian representation to a numerically stable and space efficient square-root representation. In the new parametrization, Gaussians represent the log-density function in terms of `white_vec` and a square root of the precision matrix. These two parameters are shaped to efficiently represent low-rank data, reducing space complexity from O(dim(dim+1)) to O(rank(dim+1)). In my real-world example rank=1, dim=2369, and batch_shape=(1343,), so the space reduction is 30GB → 13MB.

Computation is cheap in this representation: addition amounts to concatenation, and plate-reduction amounts to transpose and reshape. Some ops are only supported on full-rank Gaussians, and I've added checks based on the new property
`.is_full_rank`. This partial support is ok because the Gaussian funsors that arise in Bayesian models are all full rank, thanks to priors (notwithstanding numerical loss of rank).

Because the Gaussian funsor is internal, the interface change should not break most user code, since most user code uses
`to_funsor()` and `to_data()` with backend-specific distributions. One broken piece of user code is Pyro's AutoGaussianFunsor, which will need an update (and which will be sped up).

As suggested by @eb8680, I've added some optional kwarg parametrizations and properties to support conversion to other Gaussian representations, e.g.
`g = Gaussian(mean=..., covariance=..., inputs=...)` and `g._mean`, `g._covariance`. This allows more Gaussian math to live in gaussian.py.

Tested