Closed lucascolley closed 1 month ago
I have also found JAX's `dct` to be inaccurate compared to SciPy's: https://github.com/PlasmaControl/DESC/pull/1119/files#r1774234361 — the error is roughly 8 orders of magnitude larger, even in 64-bit floating point mode.
To see this, just change

```python
# Uses JAX
fq_2 = norm * idct(dct(f(m), type=dct_type), n=n.size, type=dct_type)
```

to

```python
# Uses scipy
fq_2 = norm * sidct(sdct(f(m), type=dct_type), n=n.size, type=dct_type)
```

and observe

```python
# JAX is less accurate than scipy.
factor = 1e8 if using_jax else 1  # jax needs this buffer to pass
np.testing.assert_allclose(fq_2, f(n), atol=1e-14 * factor)
```
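For reference, SciPy's own round trip at float64 recovers the input to near machine epsilon. A minimal sketch (using a plain random vector rather than DESC's `f(m)` test function and `norm` factor, which are specific to that PR):

```python
import numpy as np
from scipy.fft import dct, idct

rng = np.random.default_rng(0)
x = rng.standard_normal(31)  # odd length, matching the n=31 case in this report

for dct_type in (1, 2, 3, 4):
    # idct is the exact inverse of dct for a matching type and norm.
    err = np.max(np.abs(idct(dct(x, type=dct_type), type=dct_type) - x))
    # float64 stays many orders of magnitude below the 1e-8 buffer JAX needed.
    assert err < 1e-13
```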
Hi, thanks for the report! Note that JAX does computations in 32 bits unless the `jax_enable_x64` flag is explicitly set to `True` (see JAX sharp bits: 64-bit precision). If you're not enabling X64 mode, then it would be expected for JAX to have less accurate results than scipy, which does computations in 64 bits by default.
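The size of the gap is consistent with that explanation: float32 carries roughly 7 decimal digits and float64 roughly 16, so a 32-bit result compared against a 64-bit reference is off by about 8–9 orders of magnitude. A quick illustration with NumPy's machine-epsilon values (a generic check, not the reporter's code):

```python
import numpy as np

# Relative rounding error per operation for each dtype.
eps32 = np.finfo(np.float32).eps   # ~1.2e-07
eps64 = np.finfo(np.float64).eps   # ~2.2e-16

# The ratio is ~5e8: about the "8 orders of magnitude" gap in the report.
assert 1e8 < eps32 / eps64 < 1e9
```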
Additionally, even with X64=True, some GPU hardware is limited to 32-bit accuracy for some operations. For operations related to matrix multiplication, you can adjust this by configuring the `jax_default_matmul_precision` flag. That may be relevant in the case of FFT, but I'm not entirely sure.
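For completeness, that flag is set the same way as `jax_enable_x64`. A sketch of its usage with the `"highest"` value (which requests the widest accumulation the hardware supports) — note it did not resolve this particular issue, since matmul turns out not to be involved:

```python
import jax

# Request full-precision accumulation for matmul-backed ops on accelerators.
# Accepted values include "default", "high", and "highest".
jax.config.update("jax_default_matmul_precision", "highest")
```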
I think @lucascolley mentioned:
> This is after calling `jax.config.update("jax_enable_x64", True)`.
Sorry, see the update to my answer
This is an example of why a minimal reproducible example is helpful for any error report or question. The code block in the report doesn't do anything with `jax_enable_x64`, and descriptions of code are rarely as clear as the code itself.
Apologies, I should have included the config in the code block.
Setting `jax.config.update("jax_default_matmul_precision", 'float64')`, the inaccuracy remains. Is that what I should be trying to set it to?
It looks like `matmul` is not involved here, sorry for the wrong suggestion.
This operation dispatches directly to the HLO FftOp. I suspect that on GPU, XLA is executing this at float32 precision (many GPUs don't even have the capability to compute at float64 precision, except via multi-pass algorithms).
Hmm. That's not the case here: we should be calling NVIDIA's 64-bit fft library.
Thanks for the report @lucascolley! I am able to reproduce the issue over here, and I think I've tracked it down to this line in XLA where we're losing precision. Let me see if I can land a fix!
fantastic news! JAX's efforts to support the standard are already paying off :)
thanks again! Will try removing the skips from our tests when a new JAX release is out.
Description
With the following array as `a`:

Generated via:

The following call:

introduces inaccuracy compared to the equivalent in NumPy:

when running on CUDA. I cannot reproduce the inaccuracy when either changing `n=31` to `n=32`, or when running on CPU. This is after calling `jax.config.update("jax_enable_x64", True)`.

Is this inaccuracy expected? If so, it should be documented either at https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision or https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fft.ifft.html (or ideally, both).
Discovered in https://github.com/scipy/scipy/pull/21579#issuecomment-2366842932. cc @jakevdp
System info (python version, jaxlib version, accelerator, etc.)