Open fehiepsi opened 4 years ago
More generally I think bags-of-samples are a more flexible data structure and could be used throughout Pyro & NumPyro. Marco & Vikash said they found bags-of-samples to be the right datatype to communicate between nested inference algorithms in Gen. And @eb8680 and I have found in Funsor that the output of Monte Carlo sampling is most cleanly represented by Delta distributions that can be weighted (and whose weights and values are both differentiable).
The idea of having a bag-of-samples data structure seems nice. Would you want to reuse something from Funsor? Or just create something simple for Pyro and NumPyro to use?
@fritzo Could you suggest a way to implement this? I am not familiar with "bag-of-samples"...
I figured he was referring to what is done in NumPyro, where data points are represented as Delta distributions, so that a set of samples can be represented naturally as a mixture of these Deltas (of course, correct me if I am wrong; this is just what I gathered from the Funsor paper: https://arxiv.org/abs/1910.10775).
> Could you suggest a way to implement this?
Sorry I missed this. In many places we use sets of samples encoded as a batched tensor of shape `(num_samples,) + event_shape`. You can think of this as encoding an implicitly uniformly-weighted bag of samples, where the weight of each sample is `1 / num_samples`. More generally you could create a data structure with both weights and samples. We have already implemented bags of samples of single sites as `dist.Empirical` in Pyro. You could do something similar with a vectorized trace or sample dict by adding a `weights` or `log_weights` tensor of shape `(num_samples,)`, e.g. something like
```python
import numpy as np

class BagOfSamples:
    """A weighted collection of samples, keyed by site name."""

    def __init__(self, weights: np.ndarray, samples: dict):
        assert weights.ndim == 1
        for value in samples.values():
            # Each site's leading dimension must match the number of weights.
            assert value.shape[:1] == weights.shape
        self.weights = weights
        self.samples = samples
```
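To make the idea concrete, a weighted posterior mean could then be computed per site by normalizing the weights and contracting them against each sample array. This is only a sketch of how such a bag might be consumed, not existing Pyro/NumPyro API:

```python
import numpy as np

def weighted_mean(weights: np.ndarray, samples: dict) -> dict:
    """Per-site weighted mean for a bag of samples keyed by site name."""
    w = weights / weights.sum()  # self-normalize the weights
    # Contract the (num_samples,) weight vector against each site's leading axis.
    return {name: np.tensordot(w, value, axes=1) for name, value in samples.items()}

samples = {"mu": np.arange(8.0).reshape(4, 2)}  # 4 samples of a 2-vector
print(weighted_mean(np.ones(4), samples)["mu"])  # uniform weights -> plain mean: [3. 4.]
```

With uniform weights this reduces exactly to the implicit `1 / num_samples` convention described above.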
Actually I think the `Predictive` class could be generalized in this direction.
Thanks @fritzo and @Padarn! I got what you mean now.
@Padarn Do you want to take a stab at this? We can start with `samples` as an array (which mimics `Empirical`), but it can be generalized easily to an arbitrary pytree by using a jax utility such as `jax.tree_map` (I can help with this).
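As a rough illustration of the pytree generalization, the same weighted reduction can be mapped over every array leaf of a nested structure. The sketch below uses a minimal pure-Python stand-in for `jax.tree_map` (restricted to nested dicts), so none of these names are numpyro API:

```python
import numpy as np

def tree_map(fn, tree):
    # Minimal stand-in for jax.tree_map, handling nested dicts of arrays only.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

def weighted_mean(weights, tree):
    w = weights / weights.sum()  # self-normalize the weights
    # Apply the same weighted reduction to every array leaf in the pytree.
    return tree_map(lambda leaf: np.tensordot(w, leaf, axes=1), tree)

samples = {"loc": np.ones((5, 2)), "nested": {"scale": np.full(5, 2.0)}}
out = weighted_mean(np.ones(5), samples)
# out["loc"] -> [1., 1.]; out["nested"]["scale"] -> 2.0
```

Swapping the toy `tree_map` for `jax.tree_map` would extend this to tuples, lists, and registered pytree nodes for free.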
FYI @AdrienCorenflos already ported Empirical to numpyro in https://github.com/pyro-ppl/numpyro/issues/685
Currently, `summary` returns ordinary statistics (mean/var/quantiles/n_eff) for MCMC samples. It would be nice to extend that utility to calculate weighted statistics. This would be useful for importance sampling or nested sampling.

Proposed implementation: add a keyword `log_weights` (or `weights`) to `summary`. If `log_weights` is not None, then we return weighted mean, weighted var, weighted quantiles (instead of hpdi?), effective sample size for weighted samples, and skip `r_hat`.
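A minimal sketch of what the weighted statistics might look like. The function name and the choice of the Kish effective-sample-size formula are assumptions of mine, not existing `summary` behavior:

```python
import numpy as np

def weighted_summary(samples: np.ndarray, log_weights: np.ndarray) -> dict:
    """Weighted mean/var/ESS for 1-D samples with unnormalized log weights."""
    w = np.exp(log_weights - log_weights.max())  # stabilize before normalizing
    w = w / w.sum()
    mean = np.sum(w * samples)
    var = np.sum(w * (samples - mean) ** 2)
    # Kish effective sample size: (sum w)^2 / sum w^2, i.e. 1 / sum w^2 once normalized.
    n_eff = 1.0 / np.sum(w ** 2)
    return {"mean": mean, "var": var, "n_eff": n_eff}

stats = weighted_summary(np.array([0.0, 1.0, 2.0, 3.0]), np.zeros(4))
# Uniform log weights reduce to ordinary statistics: mean 1.5, var 1.25, n_eff 4.
```

Weighted quantiles would need a separate routine (e.g. interpolating the weighted empirical CDF), and `r_hat` is skipped because its between/within-chain variance decomposition assumes unweighted chains.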