pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Add Stable distribution with numerically integrated log-probability calculation (StableWithLogProb). #3369

BenZickel closed this pull request 6 months ago

BenZickel commented 6 months ago

This fixes #3280 by adding pyro.distributions.StableWithLogProb, which is based on pyro.distributions.Stable with an additional log_prob method (I opted not to modify the pyro.distributions.Stable distribution at this stage).

The code is based on combining https://github.com/pyro-ppl/pyro/issues/3280#issuecomment-1758461760 by @mawright with Pyro's existing Stable distribution code, with the following modifications:

Each iteration takes about 5 times longer than with reparameterization, but overall convergence is much faster, including in cases that do not converge with reparameterization (such as estimating the skew parameter beta).

The log-probability calculation is based on integrating over a uniformly distributed random variable $u$, such that $P(x) = \int P(x \mid u)\, P(u)\, du$. The integral can also be turned into a reparameterization: first sample $u$ from the density $P(u)$ (or from $g(u)$ when the posterior is approximated by a guide), then sample or observe $x$ from the distribution $P(x \mid u)$. Initial tests indicate that this reparameterization works, but it is still slower than estimating the log-probability by integration.
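To make the decomposition $P(x) = \int P(x \mid u)\, P(u)\, du$ concrete, here is a minimal stdlib-only sketch. It assumes a symmetric ($\beta = 0$) stable law with unit scale and zero location, and uses Nolan's integral representation as one valid choice of $P(x \mid u)$ with $u \sim \mathrm{Uniform}(0, \pi/2)$: given $u$, the quantity $y = |x|^{c} V(u)$ with $c = \alpha/(\alpha - 1)$ is $\mathrm{Exp}(1)$-distributed and the sign of $x$ is a fair coin. This is not necessarily the integrand used inside StableWithLogProb, and the helper names (`cond_log_prob`, `log_prob`, `sample_two_stage`) are hypothetical. The same conditional also gives the two-stage "sample $u$, then sample $x$" scheme, which for this choice reproduces the Chambers-Mallows-Stuck sampler.

```python
import math
import random


def _log_v(u, alpha):
    # log V(u) for a symmetric (beta = 0) alpha-stable law (Nolan's representation)
    c = alpha / (alpha - 1.0)
    return (c * (math.log(math.cos(u)) - math.log(math.sin(alpha * u)))
            + math.log(math.cos((alpha - 1.0) * u))
            - math.log(math.cos(u)))


def cond_log_prob(x, u, alpha):
    # Hypothetical helper: log P(x | u). Given u, y = |x|**c * V(u) ~ Exp(1)
    # and the sign of x is a fair coin, so P(x | u) = (|c| / 2) * y * exp(-y) / |x|.
    if x == 0.0:
        return -math.inf
    c = alpha / (alpha - 1.0)
    log_abs_x = math.log(abs(x))
    z = c * log_abs_x + _log_v(u, alpha)  # z = log(y)
    if z > 700.0:  # exp(z) would overflow; y * exp(-y) is effectively zero
        return -math.inf
    return math.log(abs(c) / 2.0) + z - math.exp(z) - log_abs_x


def log_prob(x, alpha, num_points=2000):
    # Hypothetical helper: log P(x) = log E_u[P(x | u)] with u ~ Uniform(0, pi/2),
    # approximated by the midpoint rule on a fixed grid of u values.
    if not 0.0 < alpha <= 2.0 or alpha == 1.0:
        raise ValueError("this sketch needs 0 < alpha <= 2 and alpha != 1")
    if x == 0.0:
        # Closed form at the mode: Gamma(1 + 1/alpha) / pi
        return math.lgamma(1.0 + 1.0 / alpha) - math.log(math.pi)
    h = (math.pi / 2) / num_points
    terms = [cond_log_prob(x, (i + 0.5) * h, alpha) for i in range(num_points)]
    m = max(terms)
    if m == -math.inf:
        return -math.inf
    # Log-mean-exp: each grid point carries weight P(u) * h = 1 / num_points
    return m + math.log(sum(math.exp(t - m) for t in terms)) - math.log(num_points)


def sample_two_stage(alpha, rng):
    # Hypothetical helper: stage 1 samples u ~ P(u), stage 2 samples x ~ P(x | u).
    # Assumes 0 < alpha <= 2 and alpha != 1, as in log_prob.
    c = alpha / (alpha - 1.0)
    u = (1.0 - rng.random()) * math.pi / 2  # in (0, pi/2], avoids u == 0
    y = rng.expovariate(1.0)                # y = |x|**c * V(u) ~ Exp(1)
    abs_x = math.exp((math.log(y) - _log_v(u, alpha)) / c)
    return u, (abs_x if rng.random() < 0.5 else -abs_x)
```

As a sanity check, $\alpha = 2$ is Gaussian with variance 2 in this parameterization, so `log_prob(1.0, 2.0)` should be close to $-1/4 - \log(2\sqrt{\pi})$. The fixed midpoint grid loses accuracy very close to $x = 0$ and far in the tails, where the integrand concentrates near an endpoint of the $u$ interval; a practical implementation would need a more careful quadrature.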

A usage example with real-life data has been added to the last section of the Stable distribution tutorial.

BenZickel commented 6 months ago

Thx @fritzo for the review!

I also think heavy-tailed inference is much needed and I really appreciate all the work done on this so far.

It might be better to combine StableWithLogProb and Stable, but I'd do it in a separate pull request (if at all). The advantage of keeping them separate is that users are made explicitly aware of both the high cost of the log-probability calculation and the option of reducing that cost, at the expense of accuracy, by reparameterization. If we do combine the two, we also need to figure out whether the behavior of MinimalReparam needs to change when handling the Stable distribution.

BenZickel commented 6 months ago

One more option that comes to mind is to keep both Stable and StableWithLogProb and add the .log_prob method to Stable. This way, a user can prevent reparameterization by using StableWithLogProb instead of Stable.