jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Arcsin Error Report #24275

Closed LilyDong0127 closed 1 week ago

LilyDong0127 commented 1 week ago

Description

Problem Description

There is a noticeable discrepancy in the results when using JAX for the arcsin function compared to other deep learning libraries such as PyTorch, TensorFlow, Keras, and Chainer. For certain input values, JAX yields results that are significantly different from those produced by the other frameworks, leading to concerns about consistency and accuracy.

import torch
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
from keras.layers import Lambda

# Input data (as provided)
input_data = np.array([
    [
        [
            0.5834500789642334,
            0.05778983607888222
        ],
        [
            0.13608911633491516,
            0.8511932492256165
        ],
        [
            -0.8579278588294983,
            -0.8257414102554321
        ],
        [
            -0.9595631957054138,
            0.665239691734314
        ],
        [
            0.5563135147094727,
            0.7400242686271667
        ],
        [
            0.9572367072105408,
            0.5983171463012695
        ],
        [
            -0.07704128324985504,
            0.5610583424568176
        ],
        [
            -0.7634511590003967,
            0.2798420488834381
        ],
        [
            -0.7132934331893921,
            0.8893378376960754
        ]
    ]
])

# PyTorch arcsin operation
def torch_arcsin(x):
    return torch.asin(torch.tensor(x, dtype=torch.float32))

# TensorFlow arcsin operation
def tf_arcsin(x):
    return tf.asin(tf.convert_to_tensor(x, dtype=tf.float32))

# Keras arcsin operation
def keras_arcsin(x):
    return Lambda(lambda x: tf.math.asin(x))(tf.convert_to_tensor(x, dtype=tf.float32))

# JAX arcsin operation
def jax_arcsin(x):
    return jnp.arcsin(jnp.array(x))

# Chainer arcsin operation
def chainer_arcsin(x):
    return np.arcsin(x)

# Calculate results
pytorch_result = torch_arcsin(input_data).detach().numpy()  # Detach to convert to numpy
tensorflow_result = tf_arcsin(input_data).numpy()
keras_result = keras_arcsin(input_data).numpy()  # Convert Keras result to numpy
jax_result = jax_arcsin(input_data)
chainer_result = chainer_arcsin(input_data)

# Print results
print(f"PyTorch arcsin result: {pytorch_result}")
print(f"TensorFlow arcsin result: {tensorflow_result}")
print(f"Keras arcsin result: {keras_result}")
print(f"JAX arcsin result: {jax_result}")
print(f"Chainer arcsin result: {chainer_result}")

# Compare results
tolerance = 1e-7  # Set a tolerance for comparison
results = {
    "PyTorch": pytorch_result,
    "TensorFlow": tensorflow_result,
    "Keras": keras_result,
    "JAX": jax_result,
    "Chainer": chainer_result
}

for name, result in results.items():
    diff = np.abs(results["PyTorch"] - result)
    print(f"Difference with {name}: {diff}")

# Check for passing criteria
passed = all(np.allclose(results["PyTorch"], result, atol=tolerance) for name, result in results.items() if name != "PyTorch")
print(f"Tests passed: {passed}")
PyTorch arcsin result: [[[ 0.62297034  0.05782205]
  [ 0.13651273  1.0182546 ]
  [-1.0312228  -0.97151536]
  [-1.2854464   0.7278148 ]
  [ 0.5899428   0.83310646]
  [ 1.2772948   0.6413992 ]
  [-0.0771177   0.5956638 ]
  [-0.8686398   0.28362957]
  [-0.7941861   1.0958949 ]]]
TensorFlow arcsin result: [[[ 0.62297034  0.05782206]
  [ 0.13651271  1.0182548 ]
  [-1.0312228  -0.97151536]
  [-1.2854464   0.727815  ]
  [ 0.5899428   0.83310646]
  [ 1.2772948   0.64139926]
  [-0.0771177   0.5956638 ]
  [-0.8686399   0.2836296 ]
  [-0.7941861   1.095895  ]]]
Keras arcsin result: [[[ 0.62297034  0.05782206]
  [ 0.13651271  1.0182548 ]
  [-1.0312228  -0.97151536]
  [-1.2854464   0.727815  ]
  [ 0.5899428   0.83310646]
  [ 1.2772948   0.64139926]
  [-0.0771177   0.5956638 ]
  [-0.8686399   0.2836296 ]
  [-0.7941861   1.095895  ]]]
JAX arcsin result: [[[ 0.62297034  0.05782205]
  [ 0.13651271  1.0182546 ]
  [-1.0312228  -0.97151536]
  [-1.2854463   0.7278148 ]
  [ 0.5899429   0.8331064 ]
  [ 1.2772946   0.6413992 ]
  [-0.0771177   0.5956638 ]
  [-0.8686398   0.28362957]
  [-0.7941861   1.0958949 ]]]
Chainer arcsin result: [[[ 0.62297033  0.05782205]
  [ 0.13651272  1.01825461]
  [-1.03122278 -0.97151538]
  [-1.28544635  0.7278148 ]
  [ 0.58994283  0.83310644]
  [ 1.27729471  0.6413992 ]
  [-0.0771177   0.59566378]
  [-0.86863983  0.28362958]
  [-0.79418614  1.09589499]]]

Significant Differences

The result 1.2772946 (for the input 0.9572367072105408) is notably different between JAX and the other libraries, which may affect the accuracy of downstream tasks.

Recommendation

It is recommended to review the implementation of the arcsin function in JAX to ensure consistency and accuracy. Special attention should be given to how floating-point arithmetic and trigonometric functions are handled, as they can significantly influence results.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:38:46) [MSC v.1929 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Lily的电脑', release='10', version='10.0.22631', machine='AMD64')
jakevdp commented 1 week ago

Thanks for the question! The reason for the discrepancy is that JAX does computations in float32 by default (see JAX Sharp Bits: double precision). If you enable 64-bit operations, then the JAX output matches the output of NumPy and other frameworks that compute in 64-bit by default:

import numpy as np
import jax
import jax.numpy as jnp  # needed for the jnp.arcsin calls below
jax.config.update('jax_enable_x64', True)

x = np.array(0.9572367072105408, dtype='float64')

print("numpy:", np.arcsin(x))
print("jax:  ", jnp.arcsin(x))
numpy: 1.2772947080197108
jax:   1.2772947080197108

and note also that with float32 inputs, NumPy's output matches the JAX output that you observed originally:

x = np.array(0.9572367072105408, dtype='float32')

print("numpy:", np.arcsin(x))
print("jax:  ", jnp.arcsin(x))
numpy: 1.2772946
jax:   1.2772946
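
To make the float32 point concrete, here is a minimal sketch that uses only the Python standard library (no JAX or NumPy). It assumes that a round trip through the `struct` format `"f"` is an acceptable stand-in for float32 storage, and the helper names `to_float32` and `float32_neighbor` are hypothetical, introduced only for illustration. It computes arcsin in float64 and then finds the two adjacent float32 values that bracket that result.

```python
import math
import struct

def to_float32(x: float) -> float:
    """Round a float64 value to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def float32_neighbor(x: float, step: int) -> float:
    """Move `step` representable float32 values away from x.

    Illustration only: valid for positive, finite, normal x.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits + step))[0]

x = 0.9572367072105408
exact = math.asin(x)          # float64 arithmetic
nearest = to_float32(exact)   # closest float32 to the float64 result

# Find the pair of adjacent float32 values that bracket the float64 result.
if nearest <= exact:
    lo, hi = nearest, float32_neighbor(nearest, +1)
else:
    lo, hi = float32_neighbor(nearest, -1), nearest

print("float64 result:   ", repr(exact))
print("float32 below:    ", repr(lo))
print("float32 above:    ", repr(hi))
print("float32 step size:", hi - lo)
```

Any float32 arcsin has to return a value on this grid, so two implementations that each carry a small rounding error can plausibly land on different neighbours of the float64 result.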
LilyDong0127 commented 1 week ago

Even with everything explicitly set to float32, there are still some differences compared with the NumPy results.


import torch
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
from keras.layers import Lambda

# Input data (as provided) - all values are set to float32
input_data = np.array([
    [
        [0.5834500789642334, 0.05778983607888222],
        [0.13608911633491516, 0.8511932492256165],
        [-0.8579278588294983, -0.8257414102554321],
        [-0.9595631957054138, 0.665239691734314],
        [0.5563135147094727, 0.7400242686271667],
        [0.9572367072105408, 0.5983171463012695],
        [-0.07704128324985504, 0.5610583424568176],
        [-0.7634511590003967, 0.2798420488834381],
        [-0.7132934331893921, 0.8893378376960754]
    ]
], dtype=np.float32)  # Ensure input data is float32

# PyTorch arcsin operation
def torch_arcsin(x):
    return torch.asin(torch.tensor(x, dtype=torch.float32))

# TensorFlow arcsin operation
def tf_arcsin(x):
    return tf.asin(tf.convert_to_tensor(x, dtype=tf.float32))

# Keras arcsin operation
def keras_arcsin(x):
    return Lambda(lambda x: tf.math.asin(x))(tf.convert_to_tensor(x, dtype=tf.float32))

# JAX arcsin operation
def jax_arcsin(x):
    return jnp.arcsin(jnp.array(x, dtype=np.float32))  # Ensure the input is float32

# Chainer arcsin operation
def chainer_arcsin(x):
    return np.arcsin(x.astype(np.float32))  # Ensure the input is float32

# Calculate results
pytorch_result = torch_arcsin(input_data).detach().numpy()  # Detach to convert to numpy
tensorflow_result = tf_arcsin(input_data).numpy()
keras_result = keras_arcsin(input_data).numpy()  # Convert Keras result to numpy
jax_result = jax_arcsin(input_data)
chainer_result = chainer_arcsin(input_data)

# Print results
print(f"PyTorch arcsin result: {pytorch_result}")
print(f"TensorFlow arcsin result: {tensorflow_result}")
print(f"Keras arcsin result: {keras_result}")
print(f"JAX arcsin result: {jax_result}")
print(f"Chainer arcsin result: {chainer_result}")

# Compare results
tolerance = 1e-7  # Set a tolerance for comparison
results = {
    "PyTorch": pytorch_result,
    "TensorFlow": tensorflow_result,
    "Keras": keras_result,
    "JAX": jax_result,
    "Chainer": chainer_result
}

for name, result in results.items():
    diff = np.abs(results["PyTorch"] - result)
    print(f"Difference with {name}: {diff}")

# Check for passing criteria
passed = all(np.allclose(results["PyTorch"], result, atol=tolerance) for name, result in results.items() if name != "PyTorch")
print(f"Tests passed: {passed}")

JAX arcsin result: [[[ 0.62297034  0.05782205]
  [ 0.13651271  1.0182546 ]
  [-1.0312228  -0.97151536]
  [-1.2854463   0.7278148 ]
  [ 0.5899429   0.8331064 ]
  [ 1.2772946   0.6413992 ]
  [-0.0771177   0.5956638 ]
  [-0.8686398   0.28362957]
  [-0.7941861   1.0958949 ]]]
Chainer arcsin result: [[[ 0.62297034  0.05782205]
  [ 0.13651273  1.0182546 ]
  [-1.0312228  -0.9715154 ]
  [-1.2854464   0.7278148 ]
  [ 0.5899428   0.83310646]
  [ 1.2772946   0.6413992 ]
  [-0.0771177   0.5956638 ]
  [-0.8686398   0.28362957]
  [-0.7941861   1.095895  ]]]

jakevdp commented 1 week ago

Can you be more specific about what differences you're talking about? All I see here are two lists of numbers that look identical when I compare the first few by-eye. What should I be looking for?

LilyDong0127 commented 1 week ago

Thanks for the prompt response. The differences are subtle but noticeable around the 7th decimal place. For example:

JAX arcsin result: [0.13651271, 1.0182546]
Chainer arcsin result: [0.13651273, 1.0182545]

As you can see, there's a slight difference in the 7th decimal place, which may seem small but can be significant for precision-critical applications.
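
For context on what a difference at that decimal place means in float32, the following sketch estimates the gap between adjacent float32 values near a few of the results quoted in this thread. It relies only on the standard library, and `float32_spacing` is a hypothetical helper that assumes normal, nonzero, finite inputs.

```python
import math

def float32_spacing(x: float) -> float:
    """Gap between adjacent float32 values near x (normal, nonzero, finite x)."""
    _, exponent = math.frexp(abs(x))   # abs(x) = m * 2**exponent with 0.5 <= m < 1
    return 2.0 ** (exponent - 24)      # float32 carries 24 significant bits

tolerance = 1e-7  # the atol used in the comparison script in this thread

for value in (0.13651271, 0.5899428, 1.0182546, 1.2772946):
    gap = float32_spacing(value)
    verdict = "larger" if gap > tolerance else "smaller"
    print(f"{value}: float32 gap = {gap:.3e} ({verdict} than atol={tolerance})")
```

For values between 1 and 2 the gap is 2**-23, roughly 1.19e-7, so a change in the 7th decimal place there can correspond to a single representable step rather than an accumulated error.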

jakevdp commented 1 week ago

What is "Chainer"?

jakevdp commented 1 week ago

I see – as in #24276, it appears the results differ by 1 ULP (one unit in the last place), which is expected for different floating point implementations. For clarity, I'm going to close this as a duplicate of #24276, and we can continue discussing the issue there if you have further questions.
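
One way to check a ULP-level claim like this is to count how many representable float32 values lie between two results. The sketch below is an illustration under stated assumptions, not JAX's or NumPy's own tooling: `float32_ordinal` and `ulp_distance` are hypothetical helpers built on the standard `struct` module, and they assume finite inputs.

```python
import struct

def float32_ordinal(x: float) -> int:
    """Map x (rounded to float32) to an integer that increases with the float value."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Negative floats have the sign bit set; mirror them below zero.
    return bits if bits < 0x80000000 else 0x80000000 - bits

def ulp_distance(a: float, b: float) -> int:
    """Number of representable float32 steps between a and b."""
    return abs(float32_ordinal(a) - float32_ordinal(b))

# One pair quoted earlier in this thread (JAX vs. NumPy, both float32).
print(ulp_distance(0.13651271, 0.13651273))  # expected: 1
```

A distance of 1 means the two results are adjacent float32 values, which is the situation described in the comment above.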

LilyDong0127 commented 1 week ago

Because Chainer does not have its own native method, I use np. You can see the code example above.

# JAX arcsin operation
def jax_arcsin(x):
    return jnp.arcsin(jnp.array(x, dtype=np.float32))  # Ensure the input is float32

# Chainer arcsin operation
def chainer_arcsin(x):
    return np.arcsin(x.astype(np.float32))  # Ensure the input is float32

jakevdp commented 1 week ago

Thanks - I'd never heard of chainer before, but I found it through a web search.