Closed BenZickel closed 6 months ago
Thx @fritzo for the review!
I also think heavy-tailed inference is much needed and I really appreciate all the work done on this so far.
It might be better to combine `StableWithLogProb` and `Stable`, but I'd do it in a separate pull request (if at all). The advantage of keeping them separate is that users will be made explicitly aware of both the high cost of the log-probability calculation and the possibility of reducing that cost, at the expense of accuracy, by reparameterization. If we do combine the two, we also need to figure out whether the behavior of `MinimalReparam` needs to be modified when handling the `Stable` distribution.

One more option that comes to mind is to keep both `Stable` and `StableWithLogProb`, and add the `.log_prob` method to `Stable`. This way a user can enforce no reparameterization by using `StableWithLogProb` instead of `Stable`.
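The last option (keeping both classes, with `MinimalReparam`-style machinery reparameterizing only distributions that lack a tractable `log_prob`) can be sketched with plain-Python stand-ins. The class names mirror Pyro's, but nothing here is the real implementation; the dispatch rule is an assumption for illustration:

```python
# Hedged sketch of the API option above: keep both classes, give
# StableWithLogProb a (costly) numerical log_prob, and let a
# MinimalReparam-style pass reparameterize only when log_prob is missing.
# These are plain-Python stand-ins, not the real Pyro classes.

class Stable:
    """Stand-in for pyro.distributions.Stable (no tractable log_prob)."""

    def log_prob(self, value):
        raise NotImplementedError("reparameterize instead of scoring directly")


class StableWithLogProb(Stable):
    """Stand-in: same distribution, plus a numerically integrated log_prob."""

    def log_prob(self, value):
        return -0.5 * value * value  # placeholder for the real integration


def needs_reparam(dist):
    """Reparameterize only distributions whose log_prob is unimplemented."""
    try:
        dist.log_prob(0.0)
        return False
    except NotImplementedError:
        return True


print(needs_reparam(Stable()), needs_reparam(StableWithLogProb()))  # → True False
```

Under this scheme a user who wants to force direct scoring just instantiates `StableWithLogProb`, while `Stable` keeps its current reparameterization-only behavior.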
This fixes #3280 by adding `pyro.distributions.StableWithLogProb`, which is based on `pyro.distributions.Stable` with an additional `log_prob` method (I opted not to modify the `pyro.distributions.Stable` distribution at this stage). The code is based on combining https://github.com/pyro-ppl/pyro/issues/3280#issuecomment-1758461760 by @mawright with the existing `Stable` distribution code in Pyro, with the following modifications:
- Handling of special cases: an `alpha` value of one, and values at and near zero.
- No dependency on the `torchquad` package (overall speed is 25% faster than the reference implementation based on `torchquad`).
- Per-iteration duration is about 5 times slower than with reparameterization, but overall convergence is much faster and includes cases which do not converge with reparameterization (like skew `beta` estimation).

The log-probability calculation is based on integration over a uniformly distributed random variable $u$ such that $P(x) = \int du \, P(x|u) P(u)$. The integral can be converted to a reparameterization where we first sample $u$ with probability density $P(u)$, or $g(u)$ when approximating the posterior distribution by a guide, and then sample or observe $x$ from the distribution $P(x|u)$. Initial tests indicate this reparameterization works but is still slower than estimating the log-probability by integration.
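The integration scheme described above can be sketched in a few lines of PyTorch. The helper below is a generic midpoint-rule approximation of $\log P(x) = \log \int_0^1 P(x|u)\,du$ computed in log space; the Normal conditional used in the example is purely illustrative and is not the stable distribution's actual $P(x|u)$:

```python
import torch


def log_prob_by_integration(x, cond_log_prob, num_points=1000):
    # Midpoint-rule approximation of log P(x) = log \int_0^1 P(x|u) du
    # for u ~ Uniform(0, 1), computed stably via logsumexp:
    #   log P(x) ≈ logsumexp_i log P(x|u_i) - log(num_points)
    u = (torch.arange(num_points, dtype=x.dtype) + 0.5) / num_points
    log_terms = cond_log_prob(x.unsqueeze(-1), u)  # shape (..., num_points)
    return torch.logsumexp(log_terms, dim=-1) - torch.log(
        torch.tensor(float(num_points), dtype=x.dtype)
    )


# Illustrative conditional: X | U=u ~ Normal(u, 1). This stands in for the
# stable distribution's P(x|u), just to demonstrate the integration scheme.
x = torch.tensor([0.0, 1.0])
lp = log_prob_by_integration(
    x, lambda x, u: torch.distributions.Normal(u, 1.0).log_prob(x)
)
print(lp)
```

Doing the sum with `logsumexp` rather than averaging raw probabilities is what keeps the estimate usable in the tails, where each $P(x|u_i)$ underflows in linear space.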
A usage example with real-life data has been added to the last section of the Stable distribution tutorial.