jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.26k stars 2.77k forks

argsort incorrectly handles very small floating-point numbers and -0.0 compared to PyTorch #24280

Open LilyDong0127 opened 5 days ago

LilyDong0127 commented 5 days ago

Description

When using JAX's argsort function on an array containing small floating-point numbers, as well as 0.0 and -0.0, the sorting order is incorrect compared to other libraries like PyTorch.

Specifically, JAX incorrectly places the very small positive number 1.401298464324817e-45 before both 0.0 and -0.0. Expected behavior is that both 0.0 and -0.0 should be treated as equivalent and placed before any positive numbers, including very small values like 1.401298464324817e-45. PyTorch demonstrates the correct behavior in this case.

import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp

def test_argsort():
    # Input data, hardcoded as float32
    input_data = np.array([
        -0.0, 1.401298464324817e-45, 1.100000023841858, -0.0,
        5.960464477539063e-08, -2.0000100135803223, 1000000.0,
        722801.375, 0.0, -1.100000023841858
    ], dtype=np.float32)

    # PyTorch argsort
    pytorch_result = torch.argsort(torch.tensor(input_data, dtype=torch.float32)).numpy()
    print(f"PyTorch argsort result: {pytorch_result}")

    # TensorFlow argsort
    tensorflow_result = tf.argsort(input_data).numpy().astype(np.int32)
    print(f"TensorFlow argsort result: {tensorflow_result}")

    # JAX argsort
    jax_result = jnp.argsort(input_data).astype(np.int32)
    print(f"JAX argsort result: {jax_result}")

if __name__ == "__main__":
    test_argsort()

Output:

PyTorch argsort result: [5 9 0 3 8 1 4 2 7 6]
TensorFlow argsort result: [5 9 0 1 3 8 4 2 7 6]
JAX argsort result: [5 9 0 1 3 8 4 2 7 6]

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:38:46) [MSC v.1929 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Lily的电脑', release='10', version='10.0.22631', machine='AMD64')

pearu commented 5 days ago

JAX arrays on CPU use FTZ mode, that is, subnormal numbers (like 1e-45) are flushed to zeros:

>>> with jax.default_device(jax.devices('cpu')[0]):
...   jax.numpy.array(numpy.array([-0.0, 1.401298464324817e-45, -0.0], numpy.float32)) == 0
... 
Array([ True,  True,  True], dtype=bool)
>>> with jax.default_device(jax.devices('cuda')[0]):
...   jax.numpy.array(numpy.array([-0.0, 1.401298464324817e-45, -0.0], numpy.float32)) == 0
... 
Array([ True, False,  True], dtype=bool)

So, the issue is not in argsort but in using FTZ mode in general on CPU.
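To see why this particular value is affected, it may help to look at its float32 encoding. The sketch below uses only the Python standard library; `float32_bits` and `is_float32_subnormal` are illustrative helper names, not JAX or NumPy APIs. It suggests that 1.401298464324817e-45 is the smallest positive float32 subnormal, which is the kind of value an FTZ-enabled backend replaces with zero, while 5.960464477539063e-08 from the same test input is an ordinary normal number.

```python
import struct


def float32_bits(x):
    """Return the IEEE 754 binary32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def is_float32_subnormal(x):
    """True if x, rounded to float32, has a zero exponent field and a nonzero mantissa."""
    bits = float32_bits(x)
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    return exponent == 0 and mantissa != 0


for value in (1.401298464324817e-45, 5.960464477539063e-08, -0.0, 0.0):
    print(f"{value!r}: bits=0x{float32_bits(value):08x} "
          f"subnormal={is_float32_subnormal(value)}")
```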

LilyDong0127 commented 5 days ago

Thank you for your response, but I'd like to emphasize that the main issue is with how JAX handles -0.0 in argsort.

  1. Handling of -0.0 in argsort: According to the IEEE 754 standard, -0.0 and 0.0 should be treated as equal. However, in JAX's argsort, it seems that -0.0 is treated differently from 0.0, leading to an incorrect sorting order. In my test case, JAX's argsort returns an index for -0.0 that suggests it's not equal to 0.0, which is not the expected behavior. In contrast, PyTorch correctly handles this case by treating -0.0 and 0.0 as equal, resulting in the expected sorting order.

  2. FTZ (Flush to Zero) mode on CPU: While the FTZ mode may explain the handling of subnormal numbers like 1.401298464324817e-45, the core issue in this particular case is the treatment of -0.0. PyTorch, also running on the same CPU, does not exhibit this issue, suggesting that the problem is not inherent to the CPU but rather how JAX is handling -0.0 in its sorting operations. The incorrect handling of -0.0 is the root cause of the inconsistent argsort results.

Would it be possible to review how JAX is dealing with -0.0 in argsort and ensure it conforms to the IEEE standard where -0.0 and 0.0 are considered equal?
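For reference, the IEEE 754 comparison rule for signed zeros mentioned here can be checked with plain Python floats. This is only a sketch of the comparison semantics, not of how any particular library implements argsort: the two zeros compare equal, neither is ordered before the other, and a stable sort keeps them in input order.

```python
import math

neg_zero, pos_zero = -0.0, 0.0

# IEEE 754 comparison: the two zeros are equal, so neither sorts before the other.
zeros_equal = neg_zero == pos_zero
neg_before_pos = neg_zero < pos_zero

# The sign bit is still observable, for example through copysign.
signs = (math.copysign(1.0, neg_zero), math.copysign(1.0, pos_zero))

# Python's sorted() is stable, so equal zeros keep their input order.
values = [0.0, -0.0, 0.0, -1.0]
order = sorted(range(len(values)), key=lambda i: values[i])

print(zeros_equal, neg_before_pos, signs, order)
# → True False (-1.0, 1.0) [3, 0, 1, 2]
```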

pearu commented 5 days ago

In processors that support the FTZ flag, enabling FTZ is optional. PyTorch evidently does not enable FTZ mode, while JAX (read: one of its underlying components) does on CPU.

in JAX's argsort, it seems that -0.0 is treated differently from 0.0

Looking at your test output, I would conclude that -0.0, 1e-45, 0.0 are all treated as equal because in argsort output, their relative order is unchanged (as per its stable=True option).
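This reading of the output can be checked with a small pure-Python simulation. It is a sketch under stated assumptions, not what JAX or PyTorch run internally: `flush_to_zero` and `stable_argsort` are hypothetical helpers, flushing is modeled as replacing any nonzero value below the smallest normal float32 (2**-126) with a signed zero, and Python's stable `sorted` stands in for a stable argsort.

```python
import math

FLOAT32_MIN_NORMAL = 2.0 ** -126  # smallest positive normal float32


def flush_to_zero(x):
    """Model FTZ: replace a float32-subnormal value with a zero of the same sign."""
    if x != 0.0 and abs(x) < FLOAT32_MIN_NORMAL:
        return math.copysign(0.0, x)
    return x


def stable_argsort(values):
    """Indices that sort values ascending; ties keep their input order."""
    return sorted(range(len(values)), key=lambda i: values[i])


# Same input as the test case in the issue description.
data = [
    -0.0, 1.401298464324817e-45, 1.100000023841858, -0.0,
    5.960464477539063e-08, -2.0000100135803223, 1000000.0,
    722801.375, 0.0, -1.100000023841858,
]

without_ftz = stable_argsort(data)
with_ftz = stable_argsort([flush_to_zero(x) for x in data])

print("without FTZ:", without_ftz)  # → [5, 9, 0, 3, 8, 1, 4, 2, 7, 6]
print("with FTZ:   ", with_ftz)     # → [5, 9, 0, 1, 3, 8, 4, 2, 7, 6]
```

Under these assumptions the unflushed ordering matches the PyTorch result reported above and the flushed ordering matches the JAX and TensorFlow results: once index 1 becomes zero, indices 0, 1, 3, and 8 tie and keep their input order.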

LilyDong0127 commented 5 days ago

  1. CPU and GPU output should be consistent: According to the IEEE 754 standard, subnormal numbers (such as 1e-45) should not produce different results on the CPU and GPU. JAX uses the FTZ (Flush to Zero) mode on the CPU to flush these very small values to zero, but this mode is not enabled on the GPU, resulting in differences in sorting and comparison results.

However, users expect consistent results on all hardware platforms, especially in basic operations such as argsort. PyTorch is able to maintain consistent behavior on the CPU and GPU, indicating that this is a problem in the JAX implementation rather than a limitation of the hardware itself. Consistency on different hardware platforms is an important principle that numerical computing frameworks should follow.

  2. PyTorch handles it correctly: In PyTorch, the results of subnormal number processing are consistent regardless of CPU or GPU. In particular, in this case, PyTorch returns the correct sort result because it does not change the processing of these small numbers due to different platforms.

In contrast, JAX's inconsistent behavior indicates that its handling of FTZ mode is not in compliance with the standard, especially when performing basic operations like argsort. This behavior leads to inconsistent results between CPU and GPU, which introduces potential bugs. JAX should behave consistently across all platforms, just like PyTorch does.

jakevdp commented 5 days ago

We confirmed in a related issue (#24281) that JAX handles signed zeros consistently in operations like argsort and argmax.

With that out of the way, it seems your main concern is that JAX treats subnormal numbers differently depending on the backend. Is that correct?

LilyDong0127 commented 5 days ago

Yes, ~because I am currently testing deep learning libraries, I may submit some bugs to you in a short period of time. For our evaluation criteria, the output results of different deep learning libraries under the same conditions should be consistent if there is no problem. Secondly, when it comes to calculation accuracy issues, we believe that special cases need to be clearly noted and marked. For these cases, for example, pytorch can handle this problem well, so we believe that jax's performance should also be consistent and correct if there are no other conditions. Secondly, according to the CPU and GPU examples you provided, it is obvious that the output for the same set of data inputs shows differences, which I think meets the requirements of bugs.~

jakevdp commented 5 days ago

I edited your response to what would have been the most helpful.

pearu commented 5 days ago

FWIW, handling subnormals in a device-dependent way complicates testing on samples with subnormals. Here's another example of the issue reported here: https://github.com/pearu/functional_algorithms/issues/38#issuecomment-2366843504 where the results of evaluating math functions near branch cuts depend on whether subnormals are flushed or not.

Enabling FTZ is an optimization method and, imho, there should exist a method (jax.config/environment variable/...) that allows controlling the state of the FTZ flag by user programs or testing scripts.
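Until such a switch exists, one possible workaround on the testing side is to compare results modulo flushing. The sketch below is a hypothetical pure-Python helper (`equal_modulo_ftz` is not part of JAX or any test library) that treats two float32 results as matching if they agree after subnormals are replaced with zero. It assumes NaNs are absent, and it deliberately hides differences that only involve subnormals.

```python
FLOAT32_MIN_NORMAL = 2.0 ** -126  # smallest positive normal float32


def flush(x):
    """Map zeros and float32 subnormals to 0.0; leave other values unchanged."""
    return 0.0 if abs(x) < FLOAT32_MIN_NORMAL else x


def equal_modulo_ftz(actual, expected):
    """True if the sequences match elementwise once subnormals are flushed.

    Assumes no NaNs; signed zeros compare equal, as in IEEE 754.
    """
    if len(actual) != len(expected):
        return False
    return all(flush(a) == flush(e) for a, e in zip(actual, expected))


# A backend that flushes and one that does not would both pass this check.
reference = [-0.0, 1.401298464324817e-45, 1.5]
flushed_result = [-0.0, 0.0, 1.5]
print(equal_modulo_ftz(flushed_result, reference))   # → True
print(equal_modulo_ftz([0.0, 0.0, 2.5], reference))  # → False
```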

jakevdp commented 5 days ago

cc/ @hawkinsp do you know whether it would be feasible to allow user-configurable FTZ semantics?