Closed LilyDong0127 closed 1 week ago
Thanks for the question! The reason for the discrepancy is that JAX performs computations in float32 by default (see JAX Sharp Bits: double precision). If you enable 64-bit operations, then the JAX output matches the output of NumPy and other frameworks that compute in 64-bit by default:
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

x = np.array(0.9572367072105408, dtype='float64')
print("numpy:", np.arcsin(x))
print("jax: ", jnp.arcsin(x))
numpy: 1.2772947080197108
jax: 1.2772947080197108
and note also that with float32 inputs, NumPy's output matches the JAX output that you observed originally:
x = np.array(0.9572367072105408, dtype='float32')
print("numpy:", np.arcsin(x))
print("jax: ", jnp.arcsin(x))
numpy: 1.2772946
jax: 1.2772946
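For completeness, here's a minimal sketch (run in a fresh session with the default configuration, i.e. without jax_enable_x64) showing where the discrepancy originates: JAX silently downcasts float64 inputs to float32:

import numpy as np
import jax.numpy as jnp

# With the default config (no jax_enable_x64), JAX truncates 64-bit inputs to 32-bit.
x = np.array(0.9572367072105408, dtype='float64')
print(x.dtype)               # float64
print(jnp.asarray(x).dtype)  # float32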
Similarly, all of the inputs below are explicitly specified as float32, and there are still some differences when the results are compared against np.arcsin:
import torch
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
from keras.layers import Lambda

# Ensure input data is float32
input_data = np.array([[
    [0.5834500789642334, 0.05778983607888222],
    [0.13608911633491516, 0.8511932492256165],
    [-0.8579278588294983, -0.8257414102554321],
    [-0.9595631957054138, 0.665239691734314],
    [0.5563135147094727, 0.7400242686271667],
    [0.9572367072105408, 0.5983171463012695],
    [-0.07704128324985504, 0.5610583424568176],
    [-0.7634511590003967, 0.2798420488834381],
    [-0.7132934331893921, 0.8893378376960754],
]], dtype=np.float32)

def torch_arcsin(x):
    return torch.asin(torch.tensor(x, dtype=torch.float32))

def tf_arcsin(x):
    return tf.asin(tf.convert_to_tensor(x, dtype=tf.float32))

def keras_arcsin(x):
    return Lambda(lambda x: tf.math.asin(x))(tf.convert_to_tensor(x, dtype=tf.float32))

def jax_arcsin(x):
    return jnp.arcsin(jnp.array(x, dtype=np.float32))  # Ensure the input is float32

def chainer_arcsin(x):
    return np.arcsin(x.astype(np.float32))  # Ensure the input is float32

pytorch_result = torch_arcsin(input_data).detach().numpy()  # Detach to convert to numpy
tensorflow_result = tf_arcsin(input_data).numpy()
keras_result = keras_arcsin(input_data).numpy()  # Convert Keras result to numpy
jax_result = jax_arcsin(input_data)
chainer_result = chainer_arcsin(input_data)

print(f"PyTorch arcsin result: {pytorch_result}")
print(f"TensorFlow arcsin result: {tensorflow_result}")
print(f"Keras arcsin result: {keras_result}")
print(f"JAX arcsin result: {jax_result}")
print(f"Chainer arcsin result: {chainer_result}")

tolerance = 1e-7  # Set a tolerance for comparison
results = {
    "PyTorch": pytorch_result,
    "TensorFlow": tensorflow_result,
    "Keras": keras_result,
    "JAX": jax_result,
    "Chainer": chainer_result,
}

for name, result in results.items():
    diff = np.abs(results["PyTorch"] - result)
    print(f"Difference with {name}: {diff}")

passed = all(
    np.allclose(results["PyTorch"], result, atol=tolerance)
    for name, result in results.items()
    if name != "PyTorch"
)
print(f"Tests passed: {passed}")
JAX arcsin result: [[[ 0.62297034  0.05782205]
  [ 0.13651271  1.0182546 ]
  [-1.0312228  -0.97151536]
  [-1.2854463   0.7278148 ]
  [ 0.5899429   0.8331064 ]
  [ 1.2772946   0.6413992 ]
  [-0.0771177   0.5956638 ]
  [-0.8686398   0.28362957]
  [-0.7941861   1.0958949 ]]]
Chainer arcsin result: [[[ 0.62297034  0.05782205]
  [ 0.13651273  1.0182546 ]
  [-1.0312228  -0.9715154 ]
  [-1.2854464   0.7278148 ]
  [ 0.5899428   0.83310646]
  [ 1.2772946   0.6413992 ]
  [-0.0771177   0.5956638 ]
  [-0.8686398   0.28362957]
  [-0.7941861   1.095895  ]]]
Can you be more specific about what differences you're talking about? All I see here are two lists of numbers that look identical when I compare the first few by-eye. What should I be looking for?
Thanks for the prompt response. The differences are subtle but noticeable around the 7th decimal place. For example:
JAX arcsin result:     [0.13651271, 1.0182546]
Chainer arcsin result: [0.13651273, 1.0182546]
As you can see, there's a slight difference in the 7th decimal place of the first element, which may seem small but can be significant for precision-critical applications.
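For reference, here's a small NumPy-only sketch (separate from the framework comparison above) showing why the 7th decimal place is exactly where float32 runs out of precision:

import numpy as np

# float32 has a 24-bit significand, i.e. roughly 7 significant decimal digits.
print(np.finfo(np.float32).eps)        # 1.1920929e-07
print(np.finfo(np.float32).precision)  # 6

# Gap between adjacent float32 values near the first result quoted above:
print(np.spacing(np.float32(0.13651271)))  # ~1.4901161e-08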
What is "Chainer"?
I see – as in #24276, it appears the results differ by 1ULP, which is expected for different floating point implementations. For clarity, I'm going to close this as a duplicate of #24276, and we can continue discussing the issue there if you have further questions.
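For anyone following along, one way to confirm the 1ULP claim is to compare the bit patterns of the two float32 results. This is just an illustrative sketch, not code from this issue:

import numpy as np

def ulp_distance(a, b):
    # Count the representable float32 values between a and b by
    # reinterpreting their bit patterns as integers (valid here
    # because both values are positive).
    ai = np.array(a, dtype=np.float32).view(np.int32)
    bi = np.array(b, dtype=np.float32).view(np.int32)
    return abs(int(ai) - int(bi))

print(ulp_distance(0.13651271, 0.13651273))  # 1 -> adjacent float32 values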
Because Chainer does not have its own native method, I used NumPy for it. You can see the code example above:
# JAX arcsin operation
def jax_arcsin(x):
    return jnp.arcsin(jnp.array(x, dtype=np.float32))  # Ensure the input is float32

# "Chainer" arcsin operation (falls back to NumPy)
def chainer_arcsin(x):
    return np.arcsin(x.astype(np.float32))  # Ensure the input is float32
Thanks - I'd never heard of chainer before, but I found it through a web search.
Description
Problem Description
There is a noticeable discrepancy in the results when using JAX for the arcsin function compared to other deep learning libraries such as PyTorch, TensorFlow, Keras, and Chainer. For certain input values, JAX yields results that are significantly different from those produced by the other frameworks, leading to concerns about consistency and accuracy.
Significant Differences
The result [1.2772946] (for the input 0.9572367072105408) is notably different between JAX and the other libraries, which may affect the accuracy of downstream tasks.
Recommendation
It is recommended to review the implementation of the arcsin function in JAX to ensure consistency and accuracy. Special attention should be given to how floating-point arithmetic and trigonometric functions are handled, as they can significantly influence results.
System info (python version, jaxlib version, accelerator, etc.)