google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.scipy.special.hyp1f1 unstable where scipy.special.hyp1f1 is not #21503

Open HansN87 opened 3 months ago

HansN87 commented 3 months ago

Description

I am working on a problem that involves modelling Gaussian noise added to random samples from a gamma distribution.

Computing the corresponding "convoluted-gamma" likelihood involves calculating hyp1f1(a, b, x) (see eq. 7 of arXiv:0704.1706).
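For reference, hyp1f1 denotes Kummer's confluent hypergeometric function, which has the standard power series below. Here $(q)_n$ is the rising factorial (Pochhammer symbol). The ratio of consecutive terms is a simple rational function of $n$, so a truncated-series implementation only has to decide when to stop adding terms.

```latex
{}_1F_1(a; b; x) = \sum_{n=0}^{\infty} t_n, \qquad
t_n = \frac{(a)_n}{(b)_n} \frac{x^n}{n!}, \qquad
(q)_n = q (q+1) \cdots (q+n-1), \quad (q)_0 = 1,

t_0 = 1, \qquad t_{n+1} = t_n \, \frac{(a+n)\, x}{(b+n)(n+1)}.
```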

I compared two calculations of this convoluted-gamma pdf, one using scipy and one using jax. The scipy version is stable for parameter choices that are common in our field, but the jax version is not.

The jax computation appears well behaved if I manually reduce the tolerance that determines the truncation of the series expansion of hyp1f1 used in jax (https://github.com/google/jax/blob/beaabae8f6d2ceb049840f59c6683e93a67726ae/jax/_src/scipy/special.py#L2426). For example, reducing it from 1.e-8 to 1.e-15 is enough.
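The following pure-Python sketch shows what such a tolerance controls. It implements a truncated Kummer series with a relative stopping rule. The function name `hyp1f1_series` and its `tol`/`max_terms` parameters are hypothetical. The stopping rule is a generic choice: it stops once the newest term is at most `tol` times the running sum. It is not claimed to match the exact criterion in jax/_src/scipy/special.py. The check uses the closed form 1F1(a; a; x) = e^x.

```python
import math


def hyp1f1_series(a, b, x, tol=1e-8, max_terms=10_000):
    """Truncated Kummer series for 1F1(a; b; x) (hypothetical helper).

    Stops once the newest term is at most `tol` times the running sum
    (a generic relative rule, not necessarily the one jax uses).
    Assumes b is not a non-positive integer.
    Returns (value, number_of_terms_summed).
    """
    term = 1.0   # t_0
    total = 1.0
    for n in range(max_terms):
        # t_{n+1} = t_n * (a + n) * x / ((b + n) * (n + 1))
        term *= (a + n) * x / ((b + n) * (n + 1))
        total += term
        if abs(term) <= tol * abs(total):
            return total, n + 2
    return total, max_terms + 1


# Usage: compare a loose and a tight tolerance against 1F1(a; a; x) = e^x.
x = 10.0
exact = math.exp(x)
for tol in (1e-8, 1e-15):
    value, n_terms = hyp1f1_series(2.5, 2.5, x, tol=tol)
    print(f"tol={tol:.0e}  terms={n_terms}  rel. error={abs(value - exact) / exact:.1e}")
```

The printed `terms` column shows the cost side of the trade-off: with this stopping rule, a tighter tolerance needs more terms before the loop exits.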

See the attached Jupyter notebook, gaussian_convoluted_gamma_pdf-5.pdf, for details.

Would it make sense to set this tolerance to a lower value? Alternatively, it could be exposed as an argument that the user can adjust, with the default left at its current value. I filed this as a bug rather than a feature request because I naively expected stability similar to that of scipy's version.

I also noticed that the gradients w.r.t. a and b, which are likewise computed via a truncated series expansion, appear to use a similarly low precision: https://github.com/google/jax/blob/ede94c3c81e1022af8bee3f60d6e5e0d65647c2a/jax/_src/scipy/special.py#L2478 https://github.com/google/jax/blob/ede94c3c81e1022af8bee3f60d6e5e0d65647c2a/jax/_src/scipy/special.py#L2504
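One standard way to get series for the a and b gradients is to differentiate the Pochhammer symbols term by term, using d/dq (q)_n = (q)_n * sum_{k<n} 1/(q + k). The pure-Python sketch below does this alongside the value series. It stops only when all three series meet the tolerance. It illustrates the technique under stated assumptions: the function name is hypothetical, and a and b must not be non-positive integers. It is not a transcription of the jax gradient code.

```python
import math


def hyp1f1_with_grad_ab(a, b, x, tol=1e-15, max_terms=10_000):
    """Truncated Kummer series for 1F1(a; b; x) and its a- and b-derivatives.

    Hypothetical helper. Uses d/dq (q)_n = (q)_n * sum_{k<n} 1/(q + k), so the
    derivative series reuse each term t_n, weighted by a running harmonic-type
    sum. Assumes a and b are not non-positive integers.
    Returns (value, d_value/da, d_value/db).
    """
    term = 1.0   # t_0 of the value series
    value = 1.0
    h_a = 0.0    # running sum_{k<n} 1/(a + k)
    h_b = 0.0    # running sum_{k<n} 1/(b + k)
    d_a = 0.0    # the n = 0 term does not depend on a or b
    d_b = 0.0
    for n in range(max_terms):
        # Advance t_n -> t_{n+1} and extend the harmonic sums to k = n.
        term *= (a + n) * x / ((b + n) * (n + 1))
        h_a += 1.0 / (a + n)
        h_b += 1.0 / (b + n)
        inc_a = term * h_a    # contribution to d/da
        inc_b = -term * h_b   # contribution to d/db (from differentiating 1/(b)_n)
        value += term
        d_a += inc_a
        d_b += inc_b
        # Stop only when all three series have converged to `tol`.
        if (abs(term) <= tol * abs(value)
                and abs(inc_a) <= tol * abs(d_a)
                and abs(inc_b) <= tol * abs(d_b)):
            break
    return value, d_a, d_b


# Usage: compare the series derivative in a with a central finite difference.
a, b, x, h = 1.3, 2.7, 2.0, 1e-6
value, d_a, d_b = hyp1f1_with_grad_ab(a, b, x)
fd_a = (hyp1f1_with_grad_ab(a + h, b, x)[0]
        - hyp1f1_with_grad_ab(a - h, b, x)[0]) / (2 * h)
print(d_a, fd_a)
```

Because the derivative terms carry an extra harmonic factor that grows with n, they decay slightly more slowly than the value terms. That is why the loop checks all three series before stopping.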

The asymptotic branch of hyp1f1 also uses lower precision; I haven't tested that case yet: https://github.com/google/jax/blob/ede94c3c81e1022af8bee3f60d6e5e0d65647c2a/jax/_src/scipy/special.py#L2451
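For large positive x, the power series needs many terms, so an asymptotic branch is useful. The following is a minimal sketch of the standard large-x expansion (see DLMF §13.7). It drops the exponentially smaller companion series and assumes x is real, positive and large. The name `hyp1f1_asymptotic` is hypothetical. The divergent series is cut at its smallest term, which is a common heuristic and not necessarily the rule jax applies. With this approach the attainable precision is set roughly by the size of the smallest term, not by a user-chosen tolerance, so it improves as x grows.

```python
import math


def hyp1f1_asymptotic(a, b, x, max_terms=500):
    """Large-x asymptotic approximation of 1F1(a; b; x) (hypothetical helper).

    Assumes x is real, positive and large, drops the exponentially smaller
    companion series, and assumes a and b are not non-positive integers:
        1F1(a; b; x) ~ Gamma(b)/Gamma(a) * e^x * x^(a - b)
                       * sum_s (1 - a)_s (b - a)_s / (s! x^s)
    math.exp overflows for x above roughly 709.
    """
    prefactor = math.gamma(b) / math.gamma(a) * math.exp(x) * x ** (a - b)
    term = 1.0   # u_0
    total = 1.0
    for s in range(max_terms):
        # u_{s+1} = u_s * (1 - a + s) * (b - a + s) / ((s + 1) * x)
        nxt = term * (1 - a + s) * (b - a + s) / ((s + 1) * x)
        if abs(nxt) >= abs(term):
            break  # the divergent series has started to grow: stop at the smallest term
        term = nxt
        total += term
        if term == 0.0:
            break  # series terminates (a a positive integer, or a - b a non-negative integer)
    return prefactor * total


# Usage: 1F1(1; 2; x) = (e^x - 1) / x, so both numbers should nearly coincide.
print(hyp1f1_asymptotic(1.0, 2.0, 40.0), math.expm1(40.0) / 40.0)
```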

@renecotyfanboy: thanks for adding hyp1f1 to jax.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.4
python: 3.11.5 (main, Sep 22 2023, 17:02:10) [GCC 11.4.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='hanscomp', release='6.8.0-76060800daily20240311-generic', version='#202403110203~1715181801~22.04~aba43ee SMP PREEMPT_DYNAMIC Wed M', machine='x86_64')

$ nvidia-smi
Wed May 29 16:45:24 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67                 Driver Version: 550.67         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:09:00.0 Off |                  Off |
| 30%   24C    P8             17W /  450W |   18711MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        Off |   00000000:0A:00.0 Off |                  N/A |
|  0%   33C    P8             17W /  350W |     272MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2579      G   /usr/lib/xorg/Xorg                            121MiB |
|    0   N/A  N/A      2692      G   /usr/bin/gnome-shell                           14MiB |
|    0   N/A  N/A     71703      C   .../py3_jax_latest-mr9UFGRS/bin/python      18558MiB |
|    1   N/A  N/A      2579      G   /usr/lib/xorg/Xorg                              4MiB |
|    1   N/A  N/A     71703      C   .../py3_jax_latest-mr9UFGRS/bin/python        256MiB |
+-----------------------------------------------------------------------------------------+

renecotyfanboy commented 3 months ago

I think this is a good way to deal with it; I'll look into drafting a PR.