dylanhmorris opened 2 years ago
I think it is a nice approach. We have those computations in various places. I think you can put those utilities in the `distributions/util.py` file. The implementation looks reasonable to me. How about discussing the details in your PR?
Sounds good. Will prepare one as soon as I have a chance.
When writing custom distributions, it is often helpful to have numerically stable implementations of `log_diff_exp(a, b) := log(exp(a) - exp(b))` and particularly `log1m_exp(x) := log(1 - exp(x))`. The naive implementations are numerically unstable for many probabilistic programming use cases, so probabilistic programming languages including Stan and PyMC provide numerically stable implementations of these functions (typically following Mächler, 2012). As far as I can tell, NumPyro does not, and they are not present in JAX either.
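For reference, a sketch of the stable formulation, restated for `x < 0` from Mächler's (2012) parametrization in terms of `a = -x`: `expm1` stays accurate when `exp(x)` is close to 1, `log1p` when `exp(x)` is small, with the switch at Mächler's suggested cutoff `x = -log 2`. `log_diff_exp` then reduces to `log1m_exp` by factoring `exp(a)` out of the difference.

```math
\begin{aligned}
\operatorname{log1m\_exp}(x) &= \log\left(1 - e^{x}\right) =
\begin{cases}
\log\left(-\operatorname{expm1}(x)\right), & -\log 2 < x < 0,\\
\operatorname{log1p}\left(-e^{x}\right), & x \le -\log 2,
\end{cases}\\[4pt]
\operatorname{log\_diff\_exp}(a, b) &= \log\left(e^{a} - e^{b}\right) = a + \operatorname{log1m\_exp}(b - a), \qquad a > b.
\end{aligned}
```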
I wonder whether it would be worth providing them. I have written basic implementations following Mächler for my own use. I would happily make a PR including them, but someone more experienced could probably write better/more idiomatic ones.
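A minimal JAX sketch of what such helpers could look like, assuming `x < 0` for `log1m_exp` and `a > b` for `log_diff_exp`. The names simply mirror this issue; this is not an existing NumPyro or JAX API, and where it would live (the thread suggests `distributions/util.py`) is left to the PR. It switches between `expm1` and `log1p` at `x = -log 2` as in Mächler (2012), and feeds each branch only inputs from its own region so gradients stay usable for gradient-based samplers:

```python
import jax
import jax.numpy as jnp

LOG_2 = jnp.log(2.0)  # cutoff: switch branches at x = -log(2)


def log1m_exp(x):
    """Compute log(1 - exp(x)) for x < 0 (hypothetical helper, sketch only)."""
    x = jnp.asarray(x)
    near_zero = x > -LOG_2
    # Give each branch only inputs from its own region (-1.0 is an arbitrary
    # safe filler in (-inf, 0)), so the branch jnp.where discards never
    # produces inf/nan that would leak into gradients.
    x_near = jnp.where(near_zero, x, -1.0)
    x_far = jnp.where(near_zero, -1.0, x)
    return jnp.where(
        near_zero,
        jnp.log(-jnp.expm1(x_near)),  # accurate when exp(x) is close to 1
        jnp.log1p(-jnp.exp(x_far)),   # accurate when exp(x) is small
    )


def log_diff_exp(a, b):
    """Compute log(exp(a) - exp(b)) for a > b by factoring out exp(a)."""
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    return a + log1m_exp(b - a)


# Usage: value and gradient stay finite right next to 0, where the naive
# jnp.log(1 - jnp.exp(x)) returns -inf in float32 (exp(x) rounds to 1).
value, grad = jax.value_and_grad(log1m_exp)(-1e-20)
```

The extra pair of `jnp.where` calls is the usual JAX workaround for NaN gradients through `where`: with a single `where`, differentiating at `x = -1e-20` would multiply the discarded `log1p` branch's infinite derivative by zero and yield NaN.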