pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Add auto-batched (low-rank) multivariate normal guides. #1737

Closed: tillahoffmann closed this 5 months ago

tillahoffmann commented 5 months ago

This PR implements auto-guides that support batching along the leading dimensions of the parameters. The guides are motivated by models that have conditional independence structure but possibly strong correlation within each instance of a plate. The interface is the same as that of Auto[LowRank]MultivariateNormal, with one additional argument, batch_ndims, that specifies the number of leading dimensions to treat as independent in the posterior approximation.
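As a rough illustration of what batch_ndims means, the NumPy sketch below treats the leading batch dimensions of the latent as independent and draws a correlated multivariate normal sample within each batch element. It is not the numpyro implementation; `sample_batched_mvn` is a hypothetical helper, and the shapes assume batch_ndims=1 with 10 series and 22 latents per series.

```python
import numpy as np


def sample_batched_mvn(loc, scale_tril, rng):
    """Draw one sample from independent MVNs, one per leading batch index.

    loc:        shape batch_shape + (D,)
    scale_tril: shape batch_shape + (D, D), lower-triangular Cholesky factors
    """
    eps = rng.standard_normal(loc.shape)  # standard normal noise, same shape as loc
    # Batched matrix-vector product: each batch element uses its own factor,
    # so there is no correlation across the leading (batch) dimensions.
    return loc + np.einsum("...ij,...j->...i", scale_tril, eps)


rng = np.random.default_rng(0)
batch_shape, D = (10,), 22  # hypothetical: batch_ndims=1, 10 series, 22 latents each
loc = np.zeros(batch_shape + (D,))
scale_tril = np.broadcast_to(np.eye(D), batch_shape + (D, D))
z = sample_batched_mvn(loc, scale_tril, rng)
print(z.shape)  # (10, 22)
```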

Example

Consider a random walk model with n time series and t observations per series. The number of latent parameters is then n * t + 2 * n: a matrix of latent time series (n * t innovations), one scale parameter for the random walk innovations of each series, and one scale parameter for the observation noise of each series. For concreteness, here's the model.

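import numpyro
from numpyro import distributions


def model(n, t, data=None):
    with numpyro.plate("n", n):
        # Model for time series: one innovation scale per series.
        innovation_scale = numpyro.sample(
            "innovation_scale",
            distributions.HalfCauchy(1),
        )
        # Standard-normal innovations, rescaled per series below (non-centered).
        innovations = numpyro.sample(
            "innovations",
            distributions.Normal().expand([t]).to_event(1),
        )
        series = numpyro.deterministic(
            "series",
            (innovation_scale[..., None] * innovations).cumsum(axis=-1),
        )

        # Model for observations: one noise scale per series.
        noise_scale = numpyro.sample(
            "noise_scale",
            distributions.HalfCauchy(1),
        )
        # Observed data of shape (n, t); passing it as obs keeps it out of the guide.
        numpyro.sample(
            "data",
            distributions.Normal(series, noise_scale[..., None]).to_event(1),
            obs=data,
        )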
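To generate synthetic data for experimenting with the guides, the generative process can be mimicked in plain NumPy. `simulate` is a hypothetical helper, not part of the PR; it assumes the innovations are scaled by the per-series innovation scale, as described in the paragraph introducing the model, and uses the fact that HalfCauchy(1) is the absolute value of a standard Cauchy variable.

```python
import numpy as np


def simulate(n, t, rng):
    """Hypothetical NumPy simulation of the random-walk model's generative process."""
    # One half-Cauchy scale per series for innovations and for observation noise.
    innovation_scale = np.abs(rng.standard_cauchy(n))
    noise_scale = np.abs(rng.standard_cauchy(n))
    # Standard-normal innovations, scaled per series and accumulated into random walks.
    innovations = rng.standard_normal((n, t))
    series = np.cumsum(innovation_scale[:, None] * innovations, axis=-1)
    # Observations: each series plus per-series Gaussian noise.
    data = series + noise_scale[:, None] * rng.standard_normal((n, t))
    return series, data


rng = np.random.default_rng(0)
series, data = simulate(10, 20, rng)
```

Both returned arrays have shape (n, t).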

Suppose we use different auto-guides and count the number of parameters we need to optimize. The numbers below are for n = 10 and t = 20, i.e., 220 latent dimensions in total and 22 per series.

# [guide class] [total number of parameters]
#   [parameter shapes]
AutoDiagonalNormal 440
     {'auto_loc': (220,), 'auto_scale': (220,)}
AutoLowRankMultivariateNormal 3740
     {'auto_loc': (220,), 'auto_cov_factor': (220, 15), 'auto_scale': (220,)}
AutoMultivariateNormal 48620
     {'auto_loc': (220,), 'auto_scale_tril': (220, 220)}
AutoBatchedLowRankMultivariateNormal 1540
     {'auto_loc': (10, 22), 'auto_cov_factor': (10, 22, 5), 'auto_scale': (10, 22)}
AutoBatchedMultivariateNormal 5060
     {'auto_loc': (10, 22), 'auto_scale_tril': (10, 22, 22)}
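The totals in the table follow from simple arithmetic on the parameter shapes. The sketch below reproduces them; the rank rule round(sqrt(D)) is an assumption inferred from the table (15 for D = 220 and 5 for D = 22), and `count_params` is a hypothetical helper.

```python
import math


def default_rank(latent_size):
    # Assumption: rank = round(sqrt(D)); matches 15 for D=220 and 5 for D=22.
    return int(round(math.sqrt(latent_size)))


def count_params(n, t):
    """Parameter counts for each guide in the table (full matrices counted)."""
    D = n * t + 2 * n  # total latent dimension (220 for n=10, t=20)
    d = t + 2          # latent dimension per series (22)
    r, rb = default_rank(D), default_rank(d)
    return {
        "AutoDiagonalNormal": D + D,                     # loc + scale
        "AutoLowRankMultivariateNormal": D + D * r + D,  # loc + cov_factor + scale
        "AutoMultivariateNormal": D + D * D,             # loc + scale_tril
        "AutoBatchedLowRankMultivariateNormal": n * (d + d * rb + d),
        "AutoBatchedMultivariateNormal": n * (d + d * d),
    }


counts = count_params(10, 20)
```

Note that the table counts every entry of each scale_tril matrix, including the zeros above the diagonal; only D * (D + 1) / 2 entries of a D x D lower-triangular factor are free.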

AutoDiagonalNormal, of course, has the fewest parameters and AutoMultivariateNormal by far the most. The number of location parameters (220) is the same across all guides. The batched versions need significantly fewer scale/covariance parameters (e.g., 4,840 entries in auto_scale_tril for AutoBatchedMultivariateNormal versus 48,400 for AutoMultivariateNormal), but of course they cannot model dependence between different series. There is no free lunch, but I believe these batched guides can strike a reasonable compromise between modeling dependence and computational cost.
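The statement that the batched guides cannot model dependence between series can be made precise: a batched multivariate normal is equivalent to a single multivariate normal over the flattened latent whose covariance is block-diagonal, with zero covariance across series. A small NumPy check of that equivalence, using hypothetical sizes and a hypothetical `mvn_logpdf` helper:

```python
import numpy as np


def mvn_logpdf(x, loc, cov):
    """Log density of a multivariate normal, computed via a Cholesky factorization."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, x - loc)  # whitened residual
    D = x.shape[-1]
    # log det(cov) = 2 * sum(log(diag(L)))
    return -0.5 * z @ z - np.log(np.diag(L)).sum() - 0.5 * D * np.log(2 * np.pi)


rng = np.random.default_rng(1)
n, d = 3, 4  # small sizes for illustration
loc = rng.standard_normal((n, d))
A = rng.standard_normal((n, d, d))
covs = A @ np.swapaxes(A, -1, -2) + d * np.eye(d)  # one SPD covariance per batch element
x = rng.standard_normal((n, d))

# Sum of independent per-batch log densities ...
batched = sum(mvn_logpdf(x[i], loc[i], covs[i]) for i in range(n))

# ... equals the log density of the flattened vector under a block-diagonal covariance.
big = np.zeros((n * d, n * d))
for i in range(n):
    big[i * d:(i + 1) * d, i * d:(i + 1) * d] = covs[i]
flat = mvn_logpdf(x.reshape(-1), loc.reshape(-1), big)
```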

Implementation

The implementation uses a mixin AutoBatchedMixin to

  1. determine the batch shape (and verify that a batched guide is appropriate for the model) and
  2. apply a reshaping transformation to account for the existence of batches in the variational approximation.
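The two mixin steps can be sketched in NumPy as follows. Both helpers are hypothetical, not AutoBatchedMixin's actual code: `infer_batch_shape` checks that all latent sites share the same leading batch dimensions, and `unpack_batched` splits a latent of shape batch_shape + (D,) back into per-site arrays. The sketch ignores constraining transforms (e.g., positivity of the half-Cauchy scales) and assumes sites are packed in dictionary order.

```python
import numpy as np


def infer_batch_shape(site_shapes, batch_ndims):
    """Step 1 (sketch): find the shared leading batch shape of all latent sites."""
    batch_shapes = {shape[:batch_ndims] for shape in site_shapes.values()}
    if len(batch_shapes) != 1 or any(len(s) < batch_ndims for s in site_shapes.values()):
        raise ValueError(f"sites do not share {batch_ndims} leading batch dims: {site_shapes}")
    return batch_shapes.pop()


def unpack_batched(z, site_shapes, batch_ndims):
    """Step 2 (sketch): split a latent of shape batch_shape + (D,) into sites."""
    batch_shape = z.shape[:batch_ndims]
    out, start = {}, 0
    for name, shape in site_shapes.items():
        event_shape = shape[batch_ndims:]
        size = int(np.prod(event_shape))  # np.prod(()) == 1.0 for scalar sites
        out[name] = z[..., start:start + size].reshape(batch_shape + event_shape)
        start += size
    return out


site_shapes = {"innovation_scale": (10,), "innovations": (10, 20), "noise_scale": (10,)}
batch_shape = infer_batch_shape(site_shapes, batch_ndims=1)
event_size = sum(int(np.prod(s[1:])) for s in site_shapes.values())
sites = unpack_batched(np.zeros(batch_shape + (event_size,)), site_shapes, batch_ndims=1)
```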

The two batched guides are implemented analogously to their non-batched counterparts, with the addition of the mixin and slight modifications to the parameter shapes (e.g., a leading batch dimension on auto_loc and auto_scale_tril).
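For the low-rank variant, each batch element carries its own location (D,), covariance factor (D, K), and diagonal term, matching the shapes (10, 22), (10, 22, 5), and (10, 22) in the table. The NumPy sketch below draws from the low-rank-plus-diagonal form N(loc, W W^T + diag(d)) independently per batch element; `sample_batched_lowrank` is hypothetical, and how numpyro maps auto_scale onto the diagonal term is not assumed here.

```python
import numpy as np


def sample_batched_lowrank(loc, cov_factor, cov_diag, rng):
    """One draw from N(loc, W W^T + diag(cov_diag)), independently per batch element.

    loc, cov_diag: batch_shape + (D,);  cov_factor (W): batch_shape + (D, K)
    """
    eps_k = rng.standard_normal(cov_factor.shape[:-2] + cov_factor.shape[-1:])  # batch + (K,)
    eps_d = rng.standard_normal(loc.shape)                                      # batch + (D,)
    # Low-rank term W @ eps_k plus independent diagonal noise gives covariance W W^T + diag(d).
    return loc + np.einsum("...dk,...k->...d", cov_factor, eps_k) + np.sqrt(cov_diag) * eps_d


rng = np.random.default_rng(0)
z = sample_batched_lowrank(
    np.zeros((10, 22)), rng.standard_normal((10, 22, 5)), np.ones((10, 22)), rng
)
```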

I added a ReshapeTransform to take care of the shapes. That logic could probably also be squeezed into UnpackTransform, but I decided on a separate transform because

  1. it separates concerns rather than packing more logic into UnpackTransform, and
  2. I've found myself needing to reshape samples in other settings.
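The idea behind a standalone reshape transform can be sketched as a bijection with zero log-Jacobian. `ReshapeSketch` below is a hypothetical NumPy class for illustration, not the ReshapeTransform added in this PR; its (forward_shape, inverse_shape) constructor arguments are assumptions.

```python
import numpy as np


class ReshapeSketch:
    """Hypothetical reshape bijection (not numpyro's ReshapeTransform).

    Maps arrays whose trailing dims equal inverse_shape to trailing dims forward_shape,
    leaving any leading sample dimensions untouched.
    """

    def __init__(self, forward_shape, inverse_shape):
        if np.prod(forward_shape) != np.prod(inverse_shape):
            raise ValueError("shapes must have the same number of elements")
        self.forward_shape = tuple(forward_shape)
        self.inverse_shape = tuple(inverse_shape)

    def __call__(self, x):
        sample_shape = x.shape[: x.ndim - len(self.inverse_shape)]
        return x.reshape(sample_shape + self.forward_shape)

    def inv(self, y):
        sample_shape = y.shape[: y.ndim - len(self.forward_shape)]
        return y.reshape(sample_shape + self.inverse_shape)

    def log_abs_det_jacobian(self, x, y):
        # Reshaping only relabels coordinates, so it is volume-preserving: log|det J| = 0.
        sample_shape = x.shape[: x.ndim - len(self.inverse_shape)]
        return np.zeros(sample_shape)


reshape = ReshapeSketch((10, 22), (220,))
x = np.arange(3 * 220.0).reshape(3, 220)  # 3 flat samples of a 220-dim latent
y = reshape(x)                            # shape (3, 10, 22)
```

Keeping this separate from unpacking means the same object can reshape samples wherever a flat-to-batched layout change is needed.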

Note: I didn't implement the get_base_dist, get_transform, and get_posterior methods because I couldn't find corresponding tests for them.

tillahoffmann commented 4 months ago

It turns out that for larger datasets we run into https://github.com/google/jax/issues/19885. The issue could probably be worked around in numpyro by slightly rearranging operations in the LowRankMultivariateNormal implementation. Is that of interest, or should we just wait for the upstream fix (I don't know how quickly the JAX developers usually respond)?
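For context on what "rearranging operations" can mean here, the log density of a low-rank multivariate normal admits mathematically equivalent evaluation orders: densely, by forming the D x D covariance, or via the Woodbury identity and the matrix determinant lemma, which only factorize a K x K matrix. The NumPy sketch below compares the two with hypothetical helpers; it does not reproduce or diagnose the linked JAX issue, and which ordering (if any) avoids it is not established here.

```python
import numpy as np


def lowrank_logpdf_dense(x, loc, W, d):
    """Log density of N(loc, W W^T + diag(d)) by forming the full covariance."""
    cov = W @ W.T + np.diag(d)
    _, logdet = np.linalg.slogdet(cov)
    r = x - loc
    maha = r @ np.linalg.solve(cov, r)
    return -0.5 * (maha + logdet + len(x) * np.log(2 * np.pi))


def lowrank_logpdf_woodbury(x, loc, W, d):
    """Same density via the Woodbury identity and the matrix determinant lemma."""
    r = x - loc
    dinv_r = r / d
    C = np.eye(W.shape[1]) + W.T @ (W / d[:, None])  # K x K "capacitance" matrix
    Lc = np.linalg.cholesky(C)
    u = np.linalg.solve(Lc, W.T @ dinv_r)            # whitened K-dim projection
    maha = r @ dinv_r - u @ u
    logdet = 2 * np.log(np.diag(Lc)).sum() + np.log(d).sum()
    return -0.5 * (maha + logdet + len(x) * np.log(2 * np.pi))


rng = np.random.default_rng(2)
D, K = 6, 2
x, loc = rng.standard_normal(D), rng.standard_normal(D)
W, d = rng.standard_normal((D, K)), rng.uniform(0.5, 2.0, D)
dense = lowrank_logpdf_dense(x, loc, W, d)
woodbury = lowrank_logpdf_woodbury(x, loc, W, d)
```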

fehiepsi commented 4 months ago

Oh, what a subtle issue. It would be nice to have a fix here (if the solution is simple, like changing operators around).