pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Switch to sqrt(precision) representation in Gaussian? #567

Closed by fritzo 3 years ago

fritzo commented 3 years ago

Addresses #559

This issue proposes switching the Gaussian parameters from (info_vec, precision) to (info_vec, prec_sqrt), following @fehiepsi's work in https://github.com/pyro-ppl/pyro/pull/2019.

Motivation

Our original motivation for representing Gaussians as (info_vec, precision) was to support operations on rank-deficient precision matrices, which occur in conditional probability distributions. This design choice allows us to uniformly handle priors, conditional distributions, and likelihoods by treating all three agnostically as mere factors in a factor graph.
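As an illustration of why rank-deficient precisions show up and why treating everything as a factor works, here is a minimal NumPy sketch. It is not funsor's Gaussian class. It assumes a toy model with a 2-dim latent x and a 1-dim observation y, where the weights `a` and noise `sigma` are made-up values. Each factor is stored as (info_vec, precision) over the joint z = concat(x, y). The conditional and the embedded prior are each singular on their own, and multiplying them (adding natural parameters) gives a full-rank joint.

```python
import numpy as np

# Toy setup (illustrative values, not funsor code): latent x of dim 2,
# observation y of dim 1, joint variable z = concat(x, y) of dim 3.
a = np.array([0.5, -2.0])  # hypothetical regression weights
sigma = 0.3                # hypothetical observation noise scale

# Conditional factor p(y | x) = Normal(y; a @ x, sigma**2).
# Its log density is -0.5 * (v @ z)**2 / sigma**2 + const with v = (-a, 1),
# so its precision over z is the rank-1 matrix outer(v, v) / sigma**2.
v = np.concatenate([-a, [1.0]])
cond_info_vec = np.zeros(3)
cond_precision = np.outer(v, v) / sigma**2

# Prior factor p(x) = Normal(x; 0, I), embedded into z-space. It says
# nothing about y, so its precision is also singular (rank 2).
prior_info_vec = np.zeros(3)
prior_precision = np.zeros((3, 3))
prior_precision[:2, :2] = np.eye(2)

# Multiplying factors means adding their natural parameters.
joint_info_vec = prior_info_vec + cond_info_vec
joint_precision = prior_precision + cond_precision

print(np.linalg.matrix_rank(cond_precision))   # 1
print(np.linalg.matrix_rank(prior_precision))  # 2
print(np.linalg.matrix_rank(joint_precision))  # 3
```

Neither factor could be written as a normalized Gaussian with a covariance matrix, yet both are valid factors, and their product is the ordinary joint Gaussian with covariance [[I, a], [a', a @ a + sigma**2]].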

However, while the (info_vec, precision) representation is numerically stable, it is computationally inefficient when making low-dimensional observations of a high-dimensional variable. For example, to store a conditional distribution of a 1-dimensional variable given a 1000-dimensional variable, the precision matrix has 1001**2 elements, but since it has rank 1, its square root would cost only 1001 elements. Indeed, we recognized this in https://github.com/pyro-ppl/pyro/pull/2005 and https://github.com/pyro-ppl/funsor/pull/217 and created a special AffineNormal pattern to avoid materializing rank-1 precision matrices.
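The storage arithmetic above can be checked with a small NumPy sketch. It assumes a linear-Gaussian conditional p(y | x) = Normal(y; a @ x + b, sigma**2), with made-up `a`, `b` and `sigma`, and compares the dense rank-1 precision over concat(x, y) with a single-column square-root factor. The name `sqrt_factor` is illustrative only.

```python
import numpy as np

# Made-up conditional factor p(y | x) = Normal(y; a @ x + b, sigma**2)
# for a 1-dim y and a 1000-dim x (illustrative values, not funsor code).
rng = np.random.default_rng(0)
dim_x = 1000
a = rng.normal(size=dim_x)
b = 0.7
sigma = 0.5

# Over z = concat(x, y), the log density is
#   -0.5 * (v @ z - b)**2 / sigma**2 + const,   with v = concat(-a, [1]),
# which expands to -0.5 * z' P z + info_vec' z + const.
v = np.concatenate([-a, [1.0]])
info_vec = v * b / sigma**2

# Dense precision: a (1001, 1001) matrix that has rank 1.
precision = np.outer(v, v) / sigma**2

# Square-root factor: a single (1001, 1) column with precision == S @ S.T.
sqrt_factor = (v / sigma)[:, None]

print(precision.size)    # 1002001, i.e. 1001**2
print(sqrt_factor.size)  # 1001
```

Both representations carry the same information, but the square-root form needs a factor of 1001 fewer numbers here.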

An alternative representation is the classic square root information filter (SRIF), explored by @fehiepsi in https://github.com/pyro-ppl/pyro/pull/2019. This represents a Gaussian as a pair (info_vec, prec_sqrt), of shapes batch_shape + (dim,) and batch_shape + (dim, rank), respectively, so that

precision = prec_sqrt @ prec_sqrt.transpose(-2, -1)
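A minimal NumPy sketch of working directly in the (info_vec, prec_sqrt) form, using the shapes above. `np.swapaxes(x, -2, -1)` stands in for PyTorch's `x.transpose(-2, -1)`. The helpers `to_precision` and `log_factor` are hypothetical names for illustration, not funsor API. The sketch shows two things. First, the unnormalized log density can be evaluated as -0.5 * ||prec_sqrt' x||**2 + info_vec' x without materializing the (dim, dim) precision. Second, multiplying two factors amounts to adding info vectors and concatenating square roots along the rank dimension, because [A B] [A B]' = A A' + B B'.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_shape, dim = (4,), 5

# Two hypothetical factors over the same 5-dim variable, with ranks 1 and 2.
info_vec1 = rng.normal(size=batch_shape + (dim,))
prec_sqrt1 = rng.normal(size=batch_shape + (dim, 1))
info_vec2 = rng.normal(size=batch_shape + (dim,))
prec_sqrt2 = rng.normal(size=batch_shape + (dim, 2))


def to_precision(prec_sqrt):
    # NumPy analogue of prec_sqrt @ prec_sqrt.transpose(-2, -1).
    return prec_sqrt @ np.swapaxes(prec_sqrt, -2, -1)


def log_factor(info_vec, prec_sqrt, x):
    # Unnormalized log density -0.5 * x' P x + info_vec' x, computed as
    # -0.5 * ||prec_sqrt' x||**2 so the (dim, dim) precision is never formed.
    proj = np.einsum("...dr,...d->...r", prec_sqrt, x)
    return -0.5 * np.sum(proj**2, axis=-1) + np.sum(info_vec * x, axis=-1)


# Product of the two factors: info vectors add, square roots concatenate
# along the rank dimension (rank 1 + 2 = 3).
info_vec = info_vec1 + info_vec2
prec_sqrt = np.concatenate([prec_sqrt1, prec_sqrt2], axis=-1)

x = rng.normal(size=batch_shape + (dim,))
print(prec_sqrt.shape)  # (4, 5, 3)
```

With this scheme the rank grows with every product. Once rank exceeds dim, one could compress back to a (dim, dim) factor, for example via a QR decomposition of the transposed square root (if S' = QR then S S' = R'R). This sketch leaves that step out.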

Advantages of the square root information representation include:

Design questions