Status: Open. quattro opened this issue 3 years ago.
@quattro I haven't used a truncated gamma before, but I think this would be nice to have. FYI, we have a TruncatedDistribution API whose base distribution can be Cauchy, Laplace, Logistic, Normal, or StudentT. It would probably be cleaner to follow that design: implement LeftTruncatedGamma, RightTruncatedGamma, and TwoSidedTruncatedGamma, then have TruncatedGamma dispatch to the corresponding class (based on whether low=None or high=None).
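A minimal plain-Python sketch of the suggested dispatch. The three class names come from the comment above, but their fields (concentration, rate, low, high) and the choice to raise when both bounds are None are assumptions for illustration; these dataclasses are stand-ins, not NumPyro's actual classes or its TruncatedDistribution implementation.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the three proposed classes; a real version would
# subclass a NumPyro distribution and implement sample/log_prob.


@dataclass
class LeftTruncatedGamma:
    concentration: float
    rate: float
    low: float


@dataclass
class RightTruncatedGamma:
    concentration: float
    rate: float
    high: float


@dataclass
class TwoSidedTruncatedGamma:
    concentration: float
    rate: float
    low: float
    high: float


def TruncatedGamma(concentration, rate=1.0, low=None, high=None):
    """Factory that picks the truncated class based on which bound is None."""
    if low is None and high is None:
        # Assumption for this sketch: require at least one bound.
        raise ValueError("at least one of low or high must be given")
    if high is None:
        return LeftTruncatedGamma(concentration, rate, low)
    if low is None:
        return RightTruncatedGamma(concentration, rate, high)
    if low >= high:
        raise ValueError("low must be smaller than high")
    return TwoSidedTruncatedGamma(concentration, rate, low, high)
```

Splitting the cases this way keeps each class's sampling and log-density logic simple, since a one-sided class never has to handle the missing bound.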
Great, thanks for the advice @fehiepsi. I'll likely be busy with some other things for a little while but plan on coming back to this soon.
Hi @fehiepsi, I made some progress on this, but stopped at the sampling implementation. I think there are two paths forward, and would appreciate your thoughts on the matter:

- uniform -> invCDF(gamma) sampling-based approach

In terms of implementation, the inverse-CDF approach needs gammaincinv to compute the quantiles of the gamma distribution, which JAX currently does not support (as you pointed out here: https://github.com/google/jax/issues/5350). It looks like TFP has gammaincinv implemented (along with gradients), but admittedly, I don't have the bandwidth to open a PR and port it over to JAX.

I'm hoping to move forward with this soon, as there seems to be other interest in a truncated gamma distribution, as mentioned in this thread: https://github.com/google/jax/issues/552.
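A minimal pure-Python sketch of the uniform -> invCDF(gamma) path, assuming a Gamma(concentration, rate) parameterization: draw u uniformly between F(low) and F(high), then return F^{-1}(u). Because gammaincinv is the missing piece, the sketch evaluates the regularized lower incomplete gamma P(a, x) with the standard power-series / continued-fraction pair and inverts it by bisection. The names gammainc, gammaincinv, and sample_truncated_gamma are hypothetical helpers, not JAX, TFP, or NumPyro functions; bisection is a simple stand-in for whatever inversion TFP uses, and this version is scalar-only and provides no gradients.

```python
import math
import random

_EPS = 1e-14
_FPMIN = 1e-300
_MAX_ITER = 1000


def gammainc(a, x):
    """Regularized lower incomplete gamma P(a, x), i.e. the Gamma(a, rate=1) CDF."""
    if x <= 0.0:
        return 0.0
    log_prefactor = -x + a * math.log(x) - math.lgamma(a)
    if x < a + 1.0:
        # Power series, which converges quickly for x < a + 1.
        ap, term = a, 1.0 / a
        total = term
        for _ in range(_MAX_ITER):
            ap += 1.0
            term *= x / ap
            total += term
            if abs(term) < abs(total) * _EPS:
                break
        return total * math.exp(log_prefactor)
    # Continued fraction (modified Lentz) for the upper tail Q(a, x) = 1 - P(a, x).
    b = x + 1.0 - a
    c = 1.0 / _FPMIN
    d = 1.0 / b
    h = d
    for i in range(1, _MAX_ITER):
        an = -i * (i - a)
        b += 2.0
        d = an * d + b
        if abs(d) < _FPMIN:
            d = _FPMIN
        c = b + an / c
        if abs(c) < _FPMIN:
            c = _FPMIN
        d = 1.0 / d
        delta = d * c
        h *= delta
        if abs(delta - 1.0) < _EPS:
            break
    return 1.0 - math.exp(log_prefactor) * h


def gammaincinv(a, p):
    """Invert P(a, .) by bisection: return x with P(a, x) ~= p, for 0 < p < 1."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    lo, hi = 0.0, max(1.0, a)
    while gammainc(a, hi) < p:  # grow the bracket until it contains the root
        lo, hi = hi, 2.0 * hi
    for _ in range(200):
        if hi - lo <= 1e-12 * hi:
            break
        mid = 0.5 * (lo + hi)
        if gammainc(a, mid) < p:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def sample_truncated_gamma(concentration, rate, low=None, high=None, rng=None):
    """Inverse-CDF draw from Gamma(concentration, rate) restricted to [low, high]."""
    if rng is None:
        rng = random.Random()
    # CDF values at the bounds; multiplying by rate maps x onto the rate-1 scale.
    p_low = 0.0 if low is None else gammainc(concentration, rate * low)
    p_high = 1.0 if high is None else gammainc(concentration, rate * high)
    # uniform -> invCDF(gamma): u ~ Uniform(p_low, p_high), then x = F^{-1}(u).
    u = p_low + (p_high - p_low) * rng.random()
    u = min(max(u, 1e-300), 1.0 - 1e-12)  # keep u strictly inside (0, 1)
    x = gammaincinv(concentration, u) / rate
    # Clamp to guard against round-off pushing x just outside the bounds.
    if low is not None:
        x = max(x, low)
    if high is not None:
        x = min(x, high)
    return x
```

For example, sample_truncated_gamma(2.0, 1.5, low=1.0, high=3.0, rng=random.Random(0)) returns one draw in [1, 3]. A JAX version would need a vectorized, differentiable gammaincinv, which is exactly what is missing here.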
@quattro I think the simplest solution is to make a wrapper for gammaincinv in numpyro.distributions.util (with a try/except import, like contrib.tfp) and use it in the sample method. That would unblock you. What do you think?
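A sketch of such a wrapper, assuming the JAX substrate of TFP (tensorflow_probability.substrates.jax) and its tfp.math.igammainv function, which inverts the regularized lower incomplete gamma function in its second argument. The wrapper's name and error message are illustrative assumptions, not NumPyro's actual util code.

```python
# Optional-dependency wrapper in the spirit of the try/except import used by
# numpyro.contrib.tfp; names and messages here are illustrative only.
try:
    from tensorflow_probability.substrates import jax as tfp
except ImportError:  # TFP (or its JAX substrate) is not installed
    tfp = None


def gammaincinv(a, p):
    """Inverse of the regularized lower incomplete gamma P(a, x) with respect to x."""
    if tfp is None:
        raise ImportError(
            "gammaincinv requires tensorflow_probability; "
            "install it with `pip install tensorflow-probability`."
        )
    # tfp.math.igammainv(a, p) returns x such that igamma(a, x) == p.
    return tfp.math.igammainv(a, p)
```

Inside a sample method, the wrapper could then be applied to uniform draws, e.g. gammaincinv(concentration, u) / rate for a rate-parameterized gamma, so the missing-dependency error only surfaces when sampling is actually attempted.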
@fehiepsi that works for now. If JAX ever supports gammaincinv natively, I can come back to this to minimize core external dependencies.
Just an update: I've had this implemented for a while, but a few tests still fail. I can open the PR to show which tests specifically and get feedback. My guess is that numerical precision in the inverse-CDF transform is the issue, but I'm not 100% sure yet.
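One plausible source of precision trouble with the inverse-CDF transform (a guess, not a diagnosis of the failing tests) is truncation far in a tail. The sketch below uses Exponential(1), i.e. a gamma with concentration 1 and rate 1, because its CDF F(x) = 1 - exp(-x) has a closed form: for low = 40, F(low) rounds to exactly 1.0 in float64, while the survival function S(x) = 1 - F(x) = exp(-x) is still representable. sample_left_truncated_exp is a hypothetical helper, and inverting the survival function is a swapped-in technique, not what the implementation in question does.

```python
import math

# Exponential(1) = Gamma(concentration=1, rate=1), with CDF F(x) = 1 - exp(-x).
low = 40.0  # truncate far in the right tail: sample from [low, inf)

# CDF route: F(low) = 1 - exp(-40) = 1 - 4.2e-18 rounds to exactly 1.0 in
# float64, so the interval [F(low), 1] for the uniform draw has zero width.
cdf_low = 1.0 - math.exp(-low)
print(cdf_low == 1.0)  # True


def sample_left_truncated_exp(u, low):
    """Map u in [0, 1) to Exponential(1) truncated to [low, inf) via S(x) = exp(-x)."""
    s_low = math.exp(-low)  # S(low), still representable (about 4.2e-18)
    s = s_low * (1.0 - u)   # uniform on (0, S(low)] when u is uniform on [0, 1)
    return -math.log(s)     # S^{-1}(s) = -log(s), >= low up to round-off


print(sample_left_truncated_exp(0.5, low))  # about 40.693, i.e. low + log(2)
```

The same idea carries over to a general gamma by inverting the regularized upper incomplete gamma Q(a, x) instead of P(a, x) when the truncation region sits in the right tail.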
Sure, @quattro! We can discuss the numerical issue in detail in the PR. :)
It would be great to have a truncated gamma distribution implemented in NumPyro in order to cover lower-bounded variances (or upper-bounded precisions) in probabilistic programs.
I'm happy to code this up and open a pull request after some testing.
A nice example would be Mendelian Randomization from summary statistics (i.e., two-stage least squares). A residual variance greater than 1 is consistent with heterogeneity across studies, but the inferred variance should be bounded below by 1.