Open LilyDong0127 opened 5 days ago
JAX arrays on CPU use FTZ mode, that is, subnormal numbers (like 1e-45) are flushed to zeros:
>>> import jax
>>> import numpy
>>> with jax.default_device(jax.devices('cpu')[0]):
...     jax.numpy.array(numpy.array([-0.0, 1.401298464324817e-45, -0.0], numpy.float32)) == 0
...
Array([ True, True, True], dtype=bool)
>>> with jax.default_device(jax.devices('cuda')[0]):
...     jax.numpy.array(numpy.array([-0.0, 1.401298464324817e-45, -0.0], numpy.float32)) == 0
...
Array([ True, False, True], dtype=bool)
So, the issue is not in argsort but in using FTZ mode in general on CPU.
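To make the FTZ effect concrete without needing a GPU, here is a minimal pure-Python sketch that emulates flushing float32 subnormals to zero. The helper names (`float32_bits`, `is_subnormal32`, `flush_to_zero32`) are hypothetical and only model the behavior; this is not how JAX or its CPU backend implements it.

```python
import math
import struct


def float32_bits(x: float) -> int:
    """Return the IEEE 754 binary32 bit pattern of x after rounding to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def is_subnormal32(x: float) -> bool:
    """A float32 is subnormal when its exponent field is 0 and its mantissa is not."""
    bits = float32_bits(x)
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    return exponent == 0 and mantissa != 0


def flush_to_zero32(x: float) -> float:
    """Emulate FTZ: replace a float32 subnormal with a zero of the same sign."""
    return math.copysign(0.0, x) if is_subnormal32(x) else x


values = [-0.0, 1.401298464324817e-45, -0.0]

# With flushing, every element compares equal to zero (the CPU result above).
print([flush_to_zero32(v) == 0 for v in values])  # [True, True, True]

# Without flushing, the subnormal stays nonzero (the CUDA result above).
print([v == 0 for v in values])  # [True, False, True]
```

Note that 1.401298464324817e-45 is the smallest positive float32 subnormal (bit pattern 0x00000001), so it is the first value to disappear under FTZ.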
Thank you for your response, but I'd like to emphasize that the main issue is with how JAX handles -0.0 in argsort.

1. Handling of -0.0 in argsort: According to the IEEE 754 standard, -0.0 and 0.0 should be treated as equal. However, in JAX's argsort, it seems that -0.0 is treated differently from 0.0, leading to an incorrect sorting order. In my test case, JAX's argsort returns an index for -0.0 that suggests it's not equal to 0.0, which is not the expected behavior. In contrast, PyTorch correctly handles this case by treating -0.0 and 0.0 as equal, resulting in the expected sorting order.

2. FTZ (Flush to Zero) mode on CPU: While the FTZ mode may explain the handling of subnormal numbers like 1.401298464324817e-45, the core issue in this particular case is the treatment of -0.0. PyTorch, also running on the same CPU, does not exhibit this issue, suggesting that the problem is not inherent to the CPU but rather how JAX is handling -0.0 in its sorting operations. The incorrect handling of -0.0 is the root cause of the inconsistent argsort results.

Would it be possible to review how JAX is dealing with -0.0 in argsort and ensure it conforms to the IEEE standard where -0.0 and 0.0 are considered equal?
On processors that support the FTZ flag, enabling FTZ is optional. PyTorch evidently does not enable FTZ mode, while JAX (read: one of its underlying components) does on CPU.
in JAX's argsort, it seems that -0.0 is treated differently from 0.0
Looking at your test output, I would conclude that -0.0, 1e-45, 0.0 are all treated as equal because in argsort output, their relative order is unchanged (as per its stable=True option).
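The stable-sort argument can be sketched in pure Python. The input array below is an assumption chosen for illustration (the original test array is not shown in this thread), and `flush_to_zero32` and `stable_argsort` are hypothetical helpers that only model the behavior, not JAX's implementation.

```python
import math
import struct


def flush_to_zero32(x: float) -> float:
    """Emulate FTZ for float32: subnormals become a zero of the same sign."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    is_subnormal = ((bits >> 23) & 0xFF) == 0 and (bits & 0x7FFFFF) != 0
    return math.copysign(0.0, x) if is_subnormal else x


def stable_argsort(values):
    """Indices that sort values ascending; ties keep their input order."""
    return sorted(range(len(values)), key=values.__getitem__)


sample = [-0.0, 1.401298464324817e-45, 0.0]  # assumed input, for illustration

# Without FTZ: -0.0 and 0.0 tie and keep their order; the subnormal sorts last.
print(stable_argsort(sample))  # [0, 2, 1]

# With FTZ: all three keys are zeros, so a stable sort leaves the order unchanged.
print(stable_argsort([flush_to_zero32(v) for v in sample]))  # [0, 1, 2]
```

In both cases -0.0 and 0.0 compare equal; the only difference between the two orderings comes from whether the subnormal is flushed.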
- CPU and GPU output should be consistent: According to the IEEE 754 standard, subnormal numbers (such as 1e-45) should not produce different results on the CPU and GPU. JAX uses the FTZ (Flush to Zero) mode on the CPU to flush these very small values to zero, but this mode is not enabled on the GPU, resulting in differences in sorting and comparison results.
However, users expect consistent results on all hardware platforms, especially in basic operations such as argsort. PyTorch is able to maintain consistent behavior on the CPU and GPU, indicating that this is a problem in the JAX implementation rather than a limitation of the hardware itself. Consistency on different hardware platforms is an important principle that numerical computing frameworks should follow.
In contrast, JAX's inconsistent behavior indicates that its handling of FTZ mode is not in compliance with the standard, especially when performing basic operations like argsort. This behavior leads to inconsistent results between CPU and GPU, which introduces potential bugs. JAX should behave consistently across all platforms, just like PyTorch does.
We confirmed in a related issue (#24281) that JAX handles signed zeros consistently in operations like argsort and argmax.
With that out of the way, it seems your main concern is that JAX treats subnormal numbers differently depending on the backend. Is that correct?
Yes, ~because I am currently testing deep learning libraries, I may submit some bugs to you in a short period of time. For our evaluation criteria, the output results of different deep learning libraries under the same conditions should be consistent if there is no problem. Secondly, when it comes to calculation accuracy issues, we believe that special cases need to be clearly noted and marked. For these cases, for example, pytorch can handle this problem well, so we believe that jax's performance should also be consistent and correct if there are no other conditions. Secondly, according to the CPU and GPU examples you provided, it is obvious that the output for the same set of data inputs shows differences, which I think meets the requirements of bugs.~
I edited your response to what would have been the most helpful.
FWIW, handling subnormals in a device-dependent way complicates testing on samples with subnormals. Here's another example of the issue reported here: https://github.com/pearu/functional_algorithms/issues/38#issuecomment-2366843504 where the results of evaluating math functions near branch cuts depend on whether subnormals are flushed or not.
Enabling FTZ is an optimization method and, imho, there should exist a method (jax.config/environment variable/...) that allows controlling the state of the FTZ flag by user programs or testing scripts.
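Until such a switch exists, one possible workaround on the testing side is to compare results only up to subnormal flushing. The following is a minimal sketch under that assumption; `canonicalize` and `equal_up_to_ftz` are hypothetical helper names, not part of JAX or any testing library, and the inputs are assumed to already be float32-representable values.

```python
FLOAT32_MIN_NORMAL = 2.0 ** -126  # smallest positive normal float32 value


def canonicalize(x: float) -> float:
    """Map float32 subnormals and signed zeros to +0.0; leave other values as-is."""
    return 0.0 if abs(x) < FLOAT32_MIN_NORMAL else x


def equal_up_to_ftz(expected, actual) -> bool:
    """Elementwise equality that ignores whether subnormals were flushed.

    NaNs still compare unequal, as in ordinary float comparison.
    """
    return len(expected) == len(actual) and all(
        canonicalize(e) == canonicalize(a) for e, a in zip(expected, actual)
    )


# A backend that keeps subnormals vs. one that flushes them to zero.
no_ftz = [-0.0, 1.401298464324817e-45, 1.0]
with_ftz = [0.0, 0.0, 1.0]
print(equal_up_to_ftz(no_ftz, with_ftz))  # True
print(equal_up_to_ftz([1.0], [2.0]))      # False
```

This only hides the difference in comparisons of final values; it does not help when flushing changes which branch a function takes, as in the branch-cut example linked above.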
cc/ @hawkinsp do you know whether it would be feasible to allow user-configurable FTZ semantics?
Description
When using JAX's argsort function on an array containing small floating-point numbers, as well as 0.0 and -0.0, the sorting order is incorrect compared to other libraries like PyTorch.
Specifically, JAX incorrectly places the very small positive number 1.401298464324817e-45 before both 0.0 and -0.0. Expected behavior is that both 0.0 and -0.0 should be treated as equivalent and placed before any positive numbers, including very small values like 1.401298464324817e-45. PyTorch demonstrates the correct behavior in this case.
System info (python version, jaxlib version, accelerator, etc.)