pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Resampler for weighed samples #3352

Closed BenZickel closed 7 months ago

BenZickel commented 7 months ago

Background

This pull request introduces a resampler for weighed samples. It produces equally weighed samples from the distribution specified by the generator of the weighed samples (see item 3 here).

Implementation

Notes

fritzo commented 7 months ago

Hi @BenZickel, thanks for your patience. I'm not sure I understand how MHResampler is intended to be used. Do you intend this new resampler to be used as a component in a larger training algorithm, such that the whole algorithm will be akin to Reweighted Wake Sleep? Or would the resampler be mainly used for prediction? Gosh, in either case it might help to have example usage in the docstring of MHResampler.

Also, do you understand the relationship between your MHResampler and the importance_resample utility we discussed in your PR from a couple weeks ago?

@fritzo: ...it might be nice to add some sort of utility importance_resample: WeightedSamples -> Samples to convert from the weighted representation back to an unweighted representation...

@BenZickel: I agree that we need to have some way to convert weighed samples to unweighed samples, but I believe this should be added in another pull request as there are several considerations related to multi-dimensional event samples and interpolation methods...

My thinking then was that it would be nice to bridge the two worlds: weighted versus uniform samples. I figured a simple way to convert weighted -> unweighted samples would be to add a method WeightedPredictiveResult.resample() that just called _systematic_resample under the hood, and converted types. I'm not sure this .resample() method has anything to do with your current PR other than the word 'resample' 😄
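For readers unfamiliar with the technique, here is a minimal standalone sketch of systematic resampling from a fixed set of weighted samples. It is not Pyro's `_systematic_resample`; the function name `systematic_resample` and its signature are hypothetical and written in plain Python for illustration only.

```python
import random


def systematic_resample(weights, rng=random):
    """Return indices into a fixed sample set, chosen by systematic resampling.

    `weights` are non-negative and need not be normalized.
    Hypothetical helper for illustration; not part of Pyro.
    """
    n = len(weights)
    total = sum(weights)
    # A single uniform offset is shared by n evenly spaced positions in [0, 1).
    offset = rng.random()
    indices = []
    j = 0
    cumulative = weights[0] / total
    for i in range(n):
        position = (offset + i) / n
        # Advance to the sample whose cumulative-weight interval holds position.
        while position >= cumulative and j < n - 1:
            j += 1
            cumulative += weights[j] / total
        indices.append(j)
    return indices


# Heavily weighted samples are selected several times, light ones may be dropped.
indices = systematic_resample([0.7, 0.1, 0.1, 0.1], rng=random.Random(0))
```

The output has the same length as the input and every selected index carries equal weight, but no new sample values are created: the result can only repeat or drop members of the original set.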

BenZickel commented 7 months ago

Thx @fritzo for the review! The intended use of MHResampler is mainly for prediction, and I've added an example that reflects that (the example is basically copied from the combined test for MHResampler and WeighedPredictive). Although MHResampler is mainly used for prediction, it actually creates posterior predictive samples that are independent of the guide, and therefore could produce accurate results with fewer SVI iterations and reduced overall running time (this is actually tested in this test configuration, where the SVI iteration count is reduced from 5000 to 1000).

Regarding your second point, MHResampler is a way to do importance_resample, but not as a method .resample of WeightedPredictiveResult. The reason is that resampling from a fixed set of samples (by a method .resample, for example) creates correlation between the resampled values and high variance in quantities computed from them, whereas MHResampler continuously creates new samples and selects which current samples will be replaced by the new samples. Because of this, MHResampler needs access to the callable that creates a new WeightedPredictiveResult when called (usually an instance of WeighedPredictive, but other callables will work as well).
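The idea of repeatedly drawing fresh weighted samples and deciding which stored samples to replace can be sketched as follows. This is a plain-Python illustration that assumes a standard independence Metropolis-Hastings acceptance rule (accept with probability min(1, w_new / w_old)); it is not the code in this PR, and the names `IndependenceMHResampler` and `make_weighted_sampler` are hypothetical.

```python
import math
import random


class IndependenceMHResampler:
    """Keeps a pool of samples and refreshes it from a weighted-sample callable.

    Hypothetical sketch; `weighted_sampler(n)` must return (samples, log_weights).
    """

    def __init__(self, weighted_sampler, num_samples, rng=None):
        self.weighted_sampler = weighted_sampler
        self.num_samples = num_samples
        self.rng = rng or random.Random()
        self.samples = None
        self.log_weights = None

    def __call__(self):
        new_samples, new_log_weights = self.weighted_sampler(self.num_samples)
        if self.samples is None:
            # The first call only initializes the pool.
            self.samples = list(new_samples)
            self.log_weights = list(new_log_weights)
            return list(self.samples)
        for i in range(self.num_samples):
            log_ratio = new_log_weights[i] - self.log_weights[i]
            # Accept the new sample with probability min(1, w_new / w_old).
            if log_ratio >= 0 or self.rng.random() < math.exp(log_ratio):
                self.samples[i] = new_samples[i]
                self.log_weights[i] = new_log_weights[i]
        return list(self.samples)


def make_weighted_sampler(rng):
    """Toy generator: uniform proposal on {0, 1}, target p(0)=0.9, p(1)=0.1."""
    target = {0: 0.9, 1: 0.1}

    def weighted_sampler(n):
        xs = [rng.randrange(2) for _ in range(n)]
        log_ws = [math.log(target[x] / 0.5) for x in xs]  # log(p / q)
        return xs, log_ws

    return weighted_sampler


rng = random.Random(0)
resampler = IndependenceMHResampler(make_weighted_sampler(rng), 2000, rng=rng)
for _ in range(50):
    samples = resampler()
fraction_of_zeros = samples.count(0) / len(samples)
```

In this toy setup the pool starts out distributed like the proposal and drifts toward the target as calls accumulate, so `fraction_of_zeros` should end up near 0.9. Unlike resampling a fixed set, every call can bring in sample values that were never in the pool before.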

Lastly, you mentioned _systematic_resample, which resamples from a fixed set of samples (which is not what we want, as explained in the previous paragraph). Resampling from a fixed set of samples is a necessity in Sequential Monte Carlo methods, as the samples are time (sequence) dependent and therefore new samples for the current time cannot be generated without starting from time zero.