Description
I am working on a problem that involves modelling Gaussian noise added to random samples from a gamma distribution.
Computing the corresponding "convoluted-gamma" likelihood involves calculating hyp1f1(a, b, x) (see eq. 7 of arXiv:0704.1706).
I compared two calculations of this convoluted-gamma pdf, one using scipy and one using jax, and noticed that the former is stable for parameter choices that are common in our field, while the latter is not.
The jax computation appears well behaved if I manually reduce the tolerance that determines the truncation of the series expansion of hyp1f1 used in jax: https://github.com/google/jax/blob/beaabae8f6d2ceb049840f59c6683e93a67726ae/jax/_src/scipy/special.py#L2426 For example: 1.e-8 -> 1.e-15.
See the attached Jupyter notebook for details: gaussian_convoluted_gamma_pdf-5.pdf
Would it make sense to set this tolerance to a lower value? Alternatively, it could be turned into an argument that the user can adjust, with the default left at its current value. I filed this as a bug and not as a feature request because I naively expected stability similar to scipy's version.
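To illustrate the second option, here is a minimal pure-Python sketch of a truncated Kummer series with the tolerance exposed as an argument. It is not the jax implementation: the function name hyp1f1_series, the tol and max_terms parameters, and the relative stopping criterion are all assumptions made for this illustration. The check uses the closed form hyp1f1(1, 2, x) = (exp(x) - 1) / x.

```python
import math


def hyp1f1_series(a, b, x, tol=1e-8, max_terms=500):
    """Truncated Kummer series sum_k (a)_k / (b)_k * x**k / k!.

    Hypothetical helper for illustration only. The series stops once the
    latest term is smaller than tol relative to the running sum (assumed
    criterion), or after max_terms terms.
    """
    term = 1.0   # k = 0 term of the series
    total = 1.0
    for k in range(max_terms):
        # Ratio of consecutive terms: (a + k) / (b + k) * x / (k + 1)
        term *= (a + k) / (b + k) * x / (k + 1)
        total += term
        if abs(term) < tol * abs(total):
            break
    return total


# Compare a loose and a tight tolerance against a closed form.
x = 1.0
exact = (math.exp(x) - 1.0) / x  # hyp1f1(1, 2, x)
for tol in (1e-8, 1e-15):
    approx = hyp1f1_series(1.0, 2.0, x, tol=tol)
    print(f"tol={tol:g}  abs. error={abs(approx - exact):.3e}")
```

With a signature like this, the default could stay at its current value while users who need more accuracy pass a smaller tol.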
I did notice that the gradients w.r.t. a and b, which are also computed via truncated series expansions, use a lower tolerance too. https://github.com/google/jax/blob/ede94c3c81e1022af8bee3f60d6e5e0d65647c2a/jax/_src/scipy/special.py#L2478 https://github.com/google/jax/blob/ede94c3c81e1022af8bee3f60d6e5e0d65647c2a/jax/_src/scipy/special.py#L2504
The asymptotic case of hyp1f1 also uses less precision; I haven't tested it yet. https://github.com/google/jax/blob/ede94c3c81e1022af8bee3f60d6e5e0d65647c2a/jax/_src/scipy/special.py#L2451
@renecotyfanboy : thanks for having added hyp1f1 to jax.
System info (python version, jaxlib version, accelerator, etc.)