pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Truncated Gamma #969

Open quattro opened 3 years ago

quattro commented 3 years ago

It would be great to have a truncated gamma distribution implemented in NumPyro, in order to cover lower-bounded variances (or upper-bounded precisions) in probabilistic programs.

I'm happy to code this up and open a pull request after some testing.

A nice example is Mendelian randomization from summary statistics (i.e., two-stage least squares): residual variance greater than 1 is consistent with heterogeneity across studies, so the inferred variance should be bounded below by 1.
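To make the use case concrete, here is a minimal sketch of how such a prior might look in a NumPyro model; `TruncatedGamma` and its `low` argument are hypothetical placeholders for the distribution requested here, not an existing API:

```python
import numpyro
import numpyro.distributions as dist

def model(obs_effects):
    # Hypothetical TruncatedGamma prior: residual variance is constrained to
    # be at least 1, and mass above 1 captures heterogeneity across studies.
    resid_var = numpyro.sample(
        "resid_var", dist.TruncatedGamma(concentration=2.0, rate=1.0, low=1.0)
    )
    numpyro.sample("obs", dist.Normal(0.0, resid_var ** 0.5), obs=obs_effects)
```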

fehiepsi commented 3 years ago

@quattro I haven't used a truncated Gamma before, but I think this would be nice to have. FYI, we have a TruncatedDistribution API whose base distribution can be Cauchy, Laplace, Logistic, Normal, or StudentT. It would probably be cleaner to follow that pattern: implement LeftTruncatedGamma, RightTruncatedGamma, and TwoSidedTruncatedGamma, then have TruncatedGamma dispatch to the corresponding class (based on whether low=None or high=None).
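For illustration, a rough sketch of that dispatching factory, assuming the three proposed classes exist with Gamma-like constructors (none of this is NumPyro's current API):

```python
def TruncatedGamma(concentration, rate=1.0, *, low=None, high=None):
    # Dispatch to the proposed single-purpose classes, mirroring how
    # TruncatedDistribution handles its supported base distributions.
    if low is None and high is None:
        raise ValueError("At least one of `low` or `high` must be specified.")
    if high is None:
        return LeftTruncatedGamma(concentration, rate, low=low)
    if low is None:
        return RightTruncatedGamma(concentration, rate, high=high)
    return TwoSidedTruncatedGamma(concentration, rate, low=low, high=high)
```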

quattro commented 3 years ago

Great, thanks for the advice @fehiepsi. I'll likely be busy with some other things for a little while but plan on coming back to this soon.

quattro commented 3 years ago

Hi @fehiepsi, I've made some progress on this but stopped at the sampling implementation. I think there are two paths forward and would appreciate your thoughts:

  1. Use a uniform draw pushed through the inverse gamma CDF (inverse-transform sampling)
  2. Use rejection sampling

In terms of implementation:

  1. Requires gammaincinv to compute quantiles of the gamma distribution, which JAX does not currently support (as you pointed out here: https://github.com/google/jax/issues/5350). It looks like TFP has gammaincinv implemented (along with gradients), but admittedly I don't have the bandwidth to start a PR and port it over to JAX. (A rough sketch of this approach is shown after this list.)
  2. There are a few papers describing sampling approaches for either the left/right-truncated gamma, or the more general case using latent variables, that might be worth investigating.
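As a concrete picture of path 1, here is a minimal sketch of inverse-CDF sampling for a two-sided truncated Gamma with unit rate; `gammaincinv` is passed in as a placeholder for whatever inverse regularized incomplete gamma implementation becomes available (jax.scipy.special does not provide one):

```python
import jax.random as random
from jax.scipy.special import gammainc

def sample_truncated_gamma(key, concentration, low, high, gammaincinv, shape=()):
    # Map the truncation bounds into CDF space, draw uniformly between them,
    # then push the draw back through the inverse CDF of Gamma(concentration, 1).
    cdf_low = gammainc(concentration, low)
    cdf_high = gammainc(concentration, high)
    u = random.uniform(key, shape, minval=cdf_low, maxval=cdf_high)
    return gammaincinv(concentration, u)
```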

I'm hoping to move forward with this soon, as it looks like there is other interest in a truncated gamma distribution, as mentioned in this thread: https://github.com/google/jax/issues/552.

fehiepsi commented 3 years ago

@quattro I think the simplest solution is to make a wrapper (with a try/except import, like contrib.tfp) for gammaincinv in numpyro.distributions.util and use it in the sample method. That would unblock you. What do you think?
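For example, a hedged sketch of that wrapper, assuming TFP's JAX substrate exposes the inverse regularized incomplete gamma as tfp.math.igammainv (the exact symbol should be double-checked):

```python
# Sketch of a gammaincinv helper for numpyro.distributions.util: fall back to
# TFP's JAX substrate, and fail lazily with a clear error if TFP is missing.
try:
    from tensorflow_probability.substrates import jax as tfp

    def gammaincinv(concentration, q):
        # Assumed TFP symbol: inverse of the regularized lower incomplete gamma.
        return tfp.math.igammainv(concentration, q)
except ImportError:
    def gammaincinv(concentration, q):
        raise ImportError(
            "Sampling a truncated Gamma requires tensorflow_probability; "
            "please install it so that gammaincinv is available."
        )
```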

quattro commented 3 years ago

@fehiepsi that works for now. If JAX natively supports gammaincinv at a later date, I can come back to this to minimize external dependencies in the core.

quattro commented 3 years ago

Just an update: I've had this implemented for a while, but I haven't been able to get a few tests to pass. I can open the PR to highlight which tests specifically and perhaps get feedback. My guess is that numerical precision is the issue, due to the inverse-CDF transform, but I'm not 100% sure at the moment.

fehiepsi commented 3 years ago

Sure, @quattro! We can discuss the numerical issue in detail in the PR. :)