Closed tillahoffmann closed 5 months ago
It turns out that for larger datasets, we run into https://github.com/google/jax/issues/19885. The issue could probably be worked around in numpyro by slightly rearranging operations in the `LowRankMultivariateNormal` implementation. Is that of interest, or should we just wait for the upstream fix (I don't know how quickly the jax folks usually turn these around)?
Oh, what a subtle issue. It would be nice to have a fix here (if the solution is as simple as rearranging a few operations).
This PR implements auto-guides that support batching along leading dimensions of the parameters. The guides are motivated by models that have conditional independence structure but possibly strong correlation within each instance of a `plate`. The interface is exactly the same as `Auto[LowRank]MultivariateNormal` with an additional argument `batch_ndims` that specifies the number of dimensions to treat as independent in the posterior approximation.

## Example
Consider a random walk model with `n` time series and `t` observations each. The number of parameters is then `n * t + 2 * n`: a matrix of latent time series, one scale parameter for the random walk innovations of each series, and one scale parameter for the observation noise of each series. For concreteness, here's the model.

Suppose we use different auto-guides and count the number of parameters we need to optimize. The example below is for `n = 10` and `t = 20`.
`AutoDiagonalNormal` of course has the fewest parameters and `AutoMultivariateNormal` the most. The number of location parameters is the same across all guides. The batched versions have significantly fewer scale/covariance parameters (but of course cannot model dependence between different series). There is no free lunch, but I believe these batched guides can strike a reasonable compromise between modeling dependence and computational cost.

## Implementation
The implementation uses a mixin `AutoBatchedMixin` to share the batching logic between the two guides. The two batched guides are implemented analogously to the non-batched guides, with the addition of the mixin and slight modifications to the parameters.
I added a `ReshapeTransform` to take care of the shapes. That could probably also be squeezed into the `UnpackTransform`. I decided on the former approach because `UnpackTransform` and