pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Support weighted statistics in `summary` diagnostics #810

Open fehiepsi opened 3 years ago

fehiepsi commented 3 years ago

Currently, summary returns ordinary statistics (mean, variance, quantiles, n_eff) for MCMC samples. It would be nice to extend this utility to compute weighted statistics, which would be useful for importance sampling or nested sampling.
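As a minimal sketch of where such weights come from, assume self-normalized importance sampling: draw samples from a proposal q, set the log-weights to log p(x) - log q(x) for an unnormalized target p, and normalize them with the log-sum-exp trick. The target, proposal, and helper name normalize_log_weights are illustrative assumptions, not part of NumPyro.

```python
import numpy as np

def normalize_log_weights(log_weights):
    # Hypothetical helper: subtract the max before exponentiating (log-sum-exp trick)
    # so large log-weights do not overflow, then normalize to sum to one.
    shifted = log_weights - np.max(log_weights)
    w = np.exp(shifted)
    return w / w.sum()

rng = np.random.default_rng(0)
# Assumed example: proposal q = Normal(0, 2), target p = Normal(1, 1).
x = rng.normal(0.0, 2.0, size=20_000)
# Additive normalizing constants are dropped; they cancel after normalization.
log_q = -0.5 * (x / 2.0) ** 2
log_p = -0.5 * (x - 1.0) ** 2
weights = normalize_log_weights(log_p - log_q)

weighted_mean = np.sum(weights * x)
weighted_var = np.sum(weights * (x - weighted_mean) ** 2)
print(weighted_mean, weighted_var)  # should be close to the target's mean 1 and variance 1
```

The raw samples follow q, so their unweighted mean is near 0; only the weighted statistics describe the target, which is why a weight-aware summary is needed here.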

Proposed implementation: add a keyword argument log_weights (or weights) to summary:

def summary(samples, prob=0.90, group_by_chain=True, log_weights=None):
    ...

If log_weights is not None, then we return the weighted mean, the weighted variance, weighted quantiles (instead of hpdi?), and the effective sample size for weighted samples, and skip r_hat.
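A minimal sketch of what the weighted branch could compute for a single site, assuming 1-D samples of shape (num_samples,). The names weighted_summary and weighted_quantile are hypothetical, the quantile uses the inverse of the weighted empirical CDF (no interpolation), and n_eff uses Kish's formula 1 / sum(w_i^2) for normalized weights, which is one common choice rather than anything NumPyro currently implements. r_hat is skipped, as proposed.

```python
import numpy as np

def weighted_quantile(x, weights, q):
    # Smallest sorted value whose cumulative weight reaches q (inverse weighted CDF).
    order = np.argsort(x)
    x_sorted, w_sorted = x[order], weights[order]
    cdf = np.cumsum(w_sorted)
    idx = np.searchsorted(cdf, q * cdf[-1], side="left")
    return x_sorted[min(idx, len(x_sorted) - 1)]

def weighted_summary(samples, log_weights, prob=0.90):
    """Hypothetical weighted counterpart of summary for one 1-D site (no r_hat)."""
    # Normalize the weights stably from log-weights.
    w = np.exp(log_weights - np.max(log_weights))
    w = w / w.sum()
    mean = np.sum(w * samples)
    var = np.sum(w * (samples - mean) ** 2)
    lo, hi = (1.0 - prob) / 2.0, (1.0 + prob) / 2.0
    return {
        "mean": mean,
        "std": np.sqrt(var),
        "median": weighted_quantile(samples, w, 0.5),
        f"{100 * lo:.1f}%": weighted_quantile(samples, w, lo),
        f"{100 * hi:.1f}%": weighted_quantile(samples, w, hi),
        # Kish's effective sample size for normalized weights.
        "n_eff": 1.0 / np.sum(w ** 2),
    }
```

With all log-weights equal, this reduces to the ordinary mean and equal-tailed quantiles, and n_eff equals the number of samples; if one sample carries all the weight, n_eff drops to 1.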

fritzo commented 3 years ago

More generally, I think bags of samples are a more flexible data structure that could be used throughout Pyro & NumPyro. Marco & Vikash said they found bags of samples to be the right datatype to communicate between nested inference algorithms in Gen. And @eb8680 and I have found in Funsor that the output of Monte Carlo sampling is most cleanly represented by Delta distributions that can be weighted (and whose weights and values are both differentiable).

Padarn commented 3 years ago

The idea of having a bag-of-samples data structure seems nice. Would you want to reuse something from Funsor, or just create something simple for Pyro and NumPyro to use?

fehiepsi commented 3 years ago

@fritzo Could you suggest a way to implement this? I am not familiar with "bag-of-samples"...

Padarn commented 3 years ago

I figured he was referring to what is done in NumPyro, where data points are represented as Delta distributions so that a sample can be represented naturally as a mixture distribution of these Deltas? (Of course, correct me if I am wrong; this is just what I gathered from the Funsor paper: https://arxiv.org/abs/1910.10775.)
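In symbols, the mixture-of-Deltas reading of a weighted set of samples x_1, ..., x_N with log-weights \ell_i is the standard identity below (a generic formulation, not a transcription of Funsor's internals). Expectations under the bag become weighted sums, and equal log-weights give \bar{w}_i = 1/N, recovering the ordinary sample average.

```latex
\hat{p}(x) = \sum_{i=1}^{N} \bar{w}_i \, \delta_{x_i}(x), \qquad
\bar{w}_i = \frac{\exp(\ell_i)}{\sum_{j=1}^{N} \exp(\ell_j)}, \qquad
\mathbb{E}_{\hat{p}}\left[f(x)\right] = \sum_{i=1}^{N} \bar{w}_i \, f(x_i)
```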

fritzo commented 3 years ago

Could you suggest a way to implement this?

Sorry, I missed this. In many places we use sets of samples encoded as a batched tensor of shape (num_samples,) + event_shape. You can think of this as encoding an implicitly uniformly weighted bag of samples, where the weight of each sample is 1 / num_samples. More generally, you could create a data structure with both weights and samples. We have already implemented bags of samples of single sites as dist.Empirical in Pyro. You could do something similar with a vectorized trace or sample dict by adding a weights or log_weights tensor of shape (num_samples,), e.g. something like

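import numpy as np

class BagOfSamples:
    def __init__(self, weights: np.ndarray, samples: dict):
        # one weight per sample: shape (num_samples,)
        assert len(weights.shape) == 1
        # every site's leading dimension must be the sample dimension
        for value in samples.values():
            assert value.shape[:1] == weights.shape
        self.weights = weights
        self.samples = samples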
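To make the "implicitly uniformly weighted" point concrete, here is a self-contained variant of the class above with two hypothetical methods: uniform, which wraps an ordinary batched sample dict with weights 1 / num_samples, and expectation, which contracts the leading sample axis against the weights. Both method names are illustrative assumptions, and the weights are assumed to be normalized.

```python
import numpy as np

class BagOfSamples:
    """Sketch: samples with one (normalized) weight per sample."""

    def __init__(self, weights: np.ndarray, samples: dict):
        assert weights.ndim == 1
        for value in samples.values():
            assert value.shape[:1] == weights.shape
        self.weights = weights
        self.samples = samples

    @classmethod
    def uniform(cls, samples: dict):
        # A plain batched dict is the special case w_i = 1 / num_samples.
        num_samples = next(iter(samples.values())).shape[0]
        return cls(np.full(num_samples, 1.0 / num_samples), samples)

    def expectation(self, name: str) -> np.ndarray:
        # Weighted sum over the leading sample axis; event dimensions are kept.
        return np.tensordot(self.weights, self.samples[name], axes=1)
```

For a uniform bag, expectation(name) matches samples[name].mean(axis=0), so existing code that averages over the first axis is the uniform special case of the weighted version.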

Actually, I think the Predictive class could be generalized in this direction.

fehiepsi commented 3 years ago

Thanks @fritzo and @Padarn! I got what you mean now.

@Padarn Do you want to take a stab at this? We can start with samples being an array (which mimics Empirical), but it can easily be generalized to an arbitrary pytree using a JAX utility such as jax.tree_map (I can help with this).
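The pytree generalization amounts to applying the same weighted reduction to every leaf. To keep the sketch free of a JAX dependency, it uses a tiny hand-written tree_map over dicts, lists, and plain tuples as a stand-in for jax.tree_map (jax.tree_util.tree_map in current JAX); both helper names are hypothetical and this is not JAX's implementation.

```python
import numpy as np

def tree_map(fn, tree):
    # Minimal stand-in for jax.tree_util.tree_map: recurse into dicts, lists, tuples.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)

def weighted_mean(log_weights, samples):
    """Weighted mean of each leaf of a pytree whose leaves share a leading sample axis."""
    w = np.exp(log_weights - np.max(log_weights))
    w = w / w.sum()
    # Contract the sample axis of every leaf against the normalized weights.
    return tree_map(lambda leaf: np.tensordot(w, leaf, axes=1), samples)
```

With JAX installed, swapping in jax.tree_util.tree_map and jax.numpy would keep the same structure while also handling other registered pytree types.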

FYI @AdrienCorenflos already ported Empirical to numpyro in https://github.com/pyro-ppl/numpyro/issues/685
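Drawing from a weighted bag, as an Empirical-style distribution does, comes down to picking sample indices with probability proportional to exp(log_weights) and indexing the leading axis. The resample helper below is a hypothetical NumPy sketch of that idea, not the code from the linked issue.

```python
import numpy as np

def resample(samples, log_weights, num_draws, rng):
    """Draw num_draws values from a weighted bag, picking index i with p_i proportional to exp(log_weights[i])."""
    p = np.exp(log_weights - np.max(log_weights))
    p = p / p.sum()
    idx = rng.choice(len(p), size=num_draws, replace=True, p=p)
    # Index along the leading sample axis, so event dimensions are preserved.
    return samples[idx]
```

Resampling turns a weighted bag into an unweighted one, so existing unweighted diagnostics could be applied to the result, at the cost of extra Monte Carlo noise.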