jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

BUG: jnp.fft.ifft: inaccuracy on CUDA for 64-bit array #23827

Closed lucascolley closed 1 month ago

lucascolley commented 1 month ago

Description

With the following array as a:

```
Array([[1.00000000e+00+0.j, 5.77924896e-01+0.j, 1.11554120e-01+0.j, 7.19188336e-03+0.j, 1.54861144e-04+0.j, 1.11374328e-06+0.j, 2.67528799e-09+0.j, 2.14634134e-12+0.j, 2.14634134e-12+0.j, 2.67528799e-09+0.j, 1.11374328e-06+0.j, 1.54861144e-04+0.j, 7.19188336e-03+0.j, 1.11554120e-01+0.j, 5.77924896e-01+0.j],
       [5.98393334e-01+0.j, 3.45826405e-01+0.j, 6.67532419e-02+0.j, 4.30357506e-03+0.j, 9.26678760e-05+0.j, 6.66456555e-07+0.j, 1.60087450e-09+0.j, 1.28435635e-12+0.j, 1.28435635e-12+0.j, 1.60087450e-09+0.j, 6.66456555e-07+0.j, 9.26678760e-05+0.j, 4.30357506e-03+0.j, 6.67532419e-02+0.j, 3.45826405e-01+0.j],
       [1.28217406e-01+0.j, 7.41000311e-02+0.j, 1.43031799e-02+0.j, 9.22124628e-04+0.j, 1.98558941e-05+0.j, 1.42801275e-07+0.j, 3.43018487e-10+0.j, 2.75198319e-13+0.j, 2.75198319e-13+0.j, 3.43018487e-10+0.j, 1.42801275e-07+0.j, 1.98558941e-05+0.j, 9.22124628e-04+0.j, 1.43031799e-02+0.j, 7.41000311e-02+0.j],
       [9.83740881e-03+0.j, 5.68528347e-03+0.j, 1.09740349e-03+0.j, 7.07494967e-05+0.j, 1.52343238e-06+0.j, 1.09563480e-08+0.j, 2.63179017e-11+0.j, 2.11144372e-14+0.j, 2.11144372e-14+0.j, 2.63179017e-11+0.j, 1.09563480e-08+0.j, 1.52343238e-06+0.j, 7.07494967e-05+0.j, 1.09740349e-03+0.j, 5.68528347e-03+0.j],
       [2.70263842e-04+0.j, 1.56192203e-04+0.j, 3.01490451e-05+0.j, 1.94370603e-06+0.j, 4.18533676e-08+0.j, 3.01004538e-10+0.j, 7.23033610e-13+0.j, 5.80078457e-16+0.j, 5.80078457e-16+0.j, 7.23033610e-13+0.j, 3.01004538e-10+0.j, 4.18533676e-08+0.j, 1.94370603e-06+0.j, 3.01490451e-05+0.j, 1.56192203e-04+0.j],
       [2.65869590e-06+0.j, 1.53652655e-06+0.j, 2.96588482e-07+0.j, 1.91210308e-08+0.j, 4.11728687e-10+0.j, 2.96110469e-12+0.j, 7.11277721e-15+0.j, 5.70646892e-18+0.j, 5.70646892e-18+0.j, 7.11277721e-15+0.j, 2.96110469e-12+0.j, 4.11728687e-10+0.j, 1.91210308e-08+0.j, 2.96588482e-07+0.j, 1.53652655e-06+0.j],
       [9.36532554e-09+0.j, 5.41245480e-09+0.j, 1.04474065e-09+0.j, 6.73543289e-11+0.j, 1.45032502e-12+0.j, 1.04305684e-14+0.j, 2.50549430e-17+0.j, 2.01011854e-20+0.j, 2.01011854e-20+0.j, 2.50549430e-17+0.j, 1.04305684e-14+0.j, 1.45032502e-12+0.j, 6.73543289e-11+0.j, 1.04474065e-09+0.j, 5.41245480e-09+0.j],
       [1.18127383e-11+0.j, 6.82687559e-12+0.j, 1.31775963e-12+0.j, 8.49558363e-14+0.j, 1.82933417e-15+0.j, 1.31563580e-17+0.j, 3.16024770e-20+0.j, 2.53541687e-23+0.j, 2.53541687e-23+0.j, 3.16024770e-20+0.j, 1.31563580e-17+0.j, 1.82933417e-15+0.j, 8.49558363e-14+0.j, 1.31775963e-12+0.j, 6.82687559e-12+0.j],
       [5.33521326e-15+0.j, 3.08335257e-15+0.j, 5.95165021e-16+0.j, 3.83702314e-17+0.j, 8.26217227e-19+0.j, 5.94205792e-21+0.j, 1.42732320e-23+0.j, 1.14511888e-26+0.j, 1.14511888e-26+0.j, 1.42732320e-23+0.j, 5.94205792e-21+0.j, 8.26217227e-19+0.j, 3.83702314e-17+0.j, 5.95165021e-16+0.j, 3.08335257e-15+0.j],
       [8.62832462e-19+0.j, 4.98652362e-19+0.j, 9.62525163e-20+0.j, 6.20539043e-21+0.j, 1.33619222e-22+0.j, 9.60973858e-25+0.j, 2.30832532e-27+0.j, 1.85193298e-30+0.j, 1.85193298e-30+0.j, 2.30832532e-27+0.j, 9.60973858e-25+0.j, 1.33619222e-22+0.j, 6.20539043e-21+0.j, 9.62525163e-20+0.j, 4.98652362e-19+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j],
       [0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j, 0.00000000e+00+0.j]], dtype=complex128)
```

Generated via:

```python
import numpy as np
import scipy.ndimage
import jax.numpy as jnp

a = np.zeros((31, 15), dtype="float64")
a[0, 0] = 1.0
a = jnp.asarray(a)
a = jnp.fft.rfft(a, n=31, axis=0)
a = jnp.fft.fft(a, n=15, axis=1)
a = scipy.ndimage.fourier_gaussian(a, [5.0, 2.5], 31, 0)  # XXX: with env var SCIPY_ARRAY_API=1
# agrees with the equivalent NumPy functions so far
```

The following call:

```python
jnp.fft.ifft(a, n=31, axis=1)
```

introduces inaccuracy on CUDA compared to the equivalent call in NumPy:

```
*** AssertionError:
Not equal to tolerance rtol=0, atol=1.5e-14

Mismatched elements: 80 / 240 (33.3%)
Max absolute difference among violations: 8.32258473e-09
Max relative difference among violations: 5.21540689e-08
```

I cannot reproduce the inaccuracy when changing n=31 to n=32, or when running on CPU.

This is after calling jax.config.update("jax_enable_x64", True).
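
For what it's worth, here is how the CPU comparison can be made in the same process (a sketch; it assumes the CPU platform is available alongside CUDA, and `a` is the array above):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Commit a copy of `a` to the CPU device so that the ifft below runs on CPU.
a_cpu = jax.device_put(a, jax.devices("cpu")[0])

gpu_result = jnp.fft.ifft(a, n=31, axis=1)      # runs on the default (CUDA) device
cpu_result = jnp.fft.ifft(a_cpu, n=31, axis=1)  # runs on CPU and matches NumPy for me

# This comparison fails for me, mirroring the mismatch against NumPy above.
np.testing.assert_allclose(np.asarray(gpu_result), np.asarray(cpu_result),
                           rtol=0, atol=1.5e-14)
```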


Is this inaccuracy expected? If so, it should be documented at https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision or at https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.ifft.html (or, ideally, both).

Discovered in https://github.com/scipy/scipy/pull/21579#issuecomment-2366842932. cc @jakevdp

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28.dev20240710
numpy:  2.1.1
python: 3.12.5 | packaged by conda-forge | (main, Aug  8 2024, 18:36:51) [GCC 12.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='lucas-pc', release='6.8.0-45-generic', version='#45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024', machine='x86_64')

$ nvidia-smi
Sun Sep 22 16:52:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce GTX 1060 6GB    Off | 00000000:08:00.0  On |                  N/A |
|  9%   52C    P2              28W / 120W |   5008MiB /  6144MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1939      G   /usr/lib/xorg/Xorg                          138MiB |
|    0   N/A  N/A      2513      G   /usr/bin/gnome-shell                         52MiB |
|    0   N/A  N/A      3670      G   ...irefox/4955/usr/lib/firefox/firefox      131MiB |
|    0   N/A  N/A     15822      G   ...erProcess --variations-seed-version       40MiB |
|    0   N/A  N/A     25718      C   python                                     4640MiB |
+---------------------------------------------------------------------------------------+
unalmis commented 1 month ago

I have also found JAX's dct to be inaccurate compared to scipy: https://github.com/PlasmaControl/DESC/pull/1119/files#r1774234361 (by roughly 8 or more orders of magnitude, even in 64-bit floating point mode).

To see this, just change

```python
# Uses Jax
fq_2 = norm * idct(dct(f(m), type=dct_type), n=n.size, type=dct_type)
```

to

```python
# Uses scipy
fq_2 = norm * sidct(sdct(f(m), type=dct_type), n=n.size, type=dct_type)
```

and observe

```python
# JAX is less accurate than scipy.
factor = 1e8 if using_jax else 1  # jax needs this buffer to pass
np.testing.assert_allclose(fq_2, f(n), atol=1e-14 * factor)
```
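
For reference, here is a self-contained sketch of the same kind of round-trip check, using `jax.scipy.fft` and `scipy.fft` directly; the test signal and tolerances are illustrative, not the DESC test:

```python
import numpy as np
import scipy.fft
import jax
import jax.numpy as jnp
import jax.scipy.fft

jax.config.update("jax_enable_x64", True)

x = np.cos(np.linspace(0, np.pi, 100))  # arbitrary smooth test signal

# A DCT-II followed by its inverse (both with orthonormal scaling) should round-trip.
scipy_rt = scipy.fft.idct(scipy.fft.dct(x, type=2, norm="ortho"), type=2, norm="ortho")
jax_rt = jax.scipy.fft.idct(jax.scipy.fft.dct(jnp.asarray(x), type=2, norm="ortho"),
                            type=2, norm="ortho")

np.testing.assert_allclose(scipy_rt, x, atol=1e-14)
# The DESC test reportedly needs an atol roughly 1e8 larger for the JAX version to pass.
np.testing.assert_allclose(np.asarray(jax_rt), x, atol=1e-14)
```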
jakevdp commented 1 month ago

Hi, thanks for the report! Note that JAX does computations in 32-bit precision unless the jax_enable_x64 flag is explicitly set to True (see JAX sharp bits: 64-bit precision). If you're not enabling X64 mode, then JAX would be expected to give less accurate results than scipy, which does computations in 64-bit by default.

Additionally, even with X64=True, some GPU hardware is limited to 32-bit accuracy for some operations. For operations related to matrix multiplication, you can adjust this by configuring the jax_default_matmul_precision flag. That may be relevant in the case of FFT, but I'm not entirely sure.
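
As a quick sanity check (not specific to FFT), the default float dtype shows whether X64 mode is actually active in a session:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# With X64 enabled, default float arrays are float64; otherwise they are float32.
print(jnp.zeros(3).dtype)  # float64
```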

unalmis commented 1 month ago

I think @lucascolley mentions:

> This is after calling `jax.config.update("jax_enable_x64", True)`.

jakevdp commented 1 month ago

Sorry, see the update to my answer.

This is an example of why a minimal reproducible example is helpful for any error report or question. The code block in the report doesn't do anything with jax_enable_x64, and descriptions of code are rarely as clear as the code itself.

lucascolley commented 1 month ago

Apologies, I should have included the config in the code block.
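
For reference, here is the full reproduction with the config included (a sketch: the final comparison mirrors the tolerance from the failure above, rather than the exact SciPy test that first caught this):

```python
import numpy as np
import scipy.ndimage
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Build the input: a delta, transformed and filtered by a Fourier-domain Gaussian.
a = np.zeros((31, 15), dtype="float64")
a[0, 0] = 1.0
a = jnp.asarray(a)
a = jnp.fft.rfft(a, n=31, axis=0)
a = jnp.fft.fft(a, n=15, axis=1)
a = scipy.ndimage.fourier_gaussian(a, [5.0, 2.5], 31, 0)  # with env var SCIPY_ARRAY_API=1

# Compare the JAX ifft against NumPy's on the same data.
actual = jnp.fft.ifft(a, n=31, axis=1)
expected = np.fft.ifft(np.asarray(a), n=31, axis=1)
np.testing.assert_allclose(np.asarray(actual), expected,
                           rtol=0, atol=1.5e-14)  # fails on CUDA, passes on CPU for me
```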

Setting `jax.config.update("jax_default_matmul_precision", 'float64')`, the inaccuracy remains. Is that what I should be trying to set it to?

jakevdp commented 1 month ago

It looks like matmul is not involved here, sorry for the wrong suggestion.

This operation dispatches directly to the HLO FFT op. I suspect that on GPU, XLA is executing it at float32 precision (many GPUs don't even have the capability to compute at float64 precision, except via multi-pass algorithms).
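
One way to see what is being emitted is to inspect the lowering (a quick sketch; the textual form varies across versions):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x = jnp.zeros((16, 15), dtype=jnp.complex128)  # same shape/dtype as `a` above
lowered = jax.jit(lambda v: jnp.fft.ifft(v, n=31, axis=1)).lower(x)
print(lowered.as_text())  # look for the fft op and check its element type (c128 vs c64)
```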

hawkinsp commented 1 month ago

Hmm. That's not the case here: we should be calling NVIDIA's 64-bit FFT library (cuFFT).

dfm commented 1 month ago

Thanks for the report @lucascolley! I am able to reproduce the issue over here, and I think I've tracked it down to this line in XLA where we're losing precision. Let me see if I can land a fix!

lucascolley commented 1 month ago

fantastic news! JAX's efforts to support the array API standard are already paying off :)

lucascolley commented 1 month ago

thanks again! Will try removing the skips from our tests when a new JAX release is out.