Open LilyDong0127 opened 1 month ago
I suspect this is the same thing seen in #24275: JAX is computing in 32-bit, and you're comparing to platforms that compute in 64-bit.
But please look at my latest code: I use float32 throughout, yet the final JAX result still differs from the results of the other deep learning libraries. The results of the other three libraries are identical.
import torch
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
from keras.layers import Lambda

input_data = np.array([
    [
        [5.834500789642334, 0.5778982043266296],
        [1.360891342163086, 8.511932373046875],
        [-8.579278945922852, -8.257413864135742],
        [-9.59563159942627, 6.6523966789245605],
        [5.563135147094727, 7.400242805480957],
        [9.572366714477539, 5.983171463012695],
        [-0.7704126238822937, 5.610583305358887],
        [-7.634511470794678, 2.7984204292297363],
        [-7.132934093475342, 8.893378257751465]
    ]
], dtype=np.float32)  # Ensure input data is float32

def torch_arcsinh(x):
    return torch.asinh(torch.tensor(x, dtype=torch.float32))

def tf_arcsinh(x):
    return tf.asinh(tf.convert_to_tensor(x, dtype=tf.float32))

def keras_arcsinh(x):
    return Lambda(lambda x: tf.math.asinh(x))(tf.convert_to_tensor(x, dtype=tf.float32))

def jax_arcsinh(x):
    return jnp.arcsinh(jnp.array(x, dtype=np.float32))  # Set dtype to float32

pytorch_result = torch_arcsinh(input_data)
tensorflow_result = tf_arcsinh(input_data)
keras_result = keras_arcsinh(input_data).numpy()  # Convert Keras result to numpy
jax_result = jax_arcsinh(input_data)

print(f"PyTorch arcsinh result: {pytorch_result.detach().numpy()}")  # Detach to convert to numpy
print(f"TensorFlow arcsinh result: {tensorflow_result.numpy()}")
print(f"Keras arcsinh result: {keras_result}")
print(f"JAX arcsinh result: {jax_result}")

tolerance = 1e-7  # Set a tolerance for comparison
results = {
    "PyTorch": pytorch_result.detach().numpy(),
    "TensorFlow": tensorflow_result.numpy(),
    "Keras": keras_result,
    "JAX": jax_result
}

for name, result in results.items():
    diff = np.abs(results["PyTorch"] - result)
    print(f"Difference with {name}: {diff}")

passed = all(
    np.allclose(results["PyTorch"], result, atol=tolerance)
    for name, result in results.items()
    if name != "PyTorch"
)
print(f"Tests passed: {passed}")
PyTorch arcsinh result:
[[[ 2.4642003 0.5497806]
  [ 1.1150385 2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593 2.5937262]
  [ 2.4172907 2.699194 ]
  [ 2.954745 2.4890096]
  [-0.7093974 2.4256508]
  [-2.730088 1.7526976]
  [-2.6627476 2.8816001]]]
TensorFlow arcsinh result:
[[[ 2.4642003 0.5497806]
  [ 1.1150385 2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593 2.5937262]
  [ 2.4172907 2.699194 ]
  [ 2.954745 2.4890096]
  [-0.7093974 2.4256508]
  [-2.730088 1.7526976]
  [-2.6627476 2.8816001]]]
Keras arcsinh result:
[[[ 2.4642003 0.5497806]
  [ 1.1150385 2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593 2.5937262]
  [ 2.4172907 2.699194 ]
  [ 2.954745 2.4890096]
  [-0.7093974 2.4256508]
  [-2.730088 1.7526976]
  [-2.6627476 2.8816001]]]
JAX arcsinh result:
[[[ 2.4642003 0.5497806]
  [ 1.1150384 2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593 2.5937262]
  [ 2.4172907 2.6991942]
  [ 2.9547448 2.4890094]
  [-0.7093974 2.4256508]
  [-2.730088 1.7526975]
  [-2.6627476 2.8816001]]]
Difference with PyTorch:
[[[0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.]]]
Difference with TensorFlow:
[[[0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.]]]
Difference with Keras:
[[[0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.] [0. 0.]]]
Difference with JAX:
[[[0.0000000e+00 0.0000000e+00]
  [1.1920929e-07 0.0000000e+00]
  [0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 2.3841858e-07]
  [2.3841858e-07 2.3841858e-07]
  [0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 1.1920929e-07]
  [0.0000000e+00 0.0000000e+00]]]
It appears JAX differs from PyTorch by 1 ULP (unit in the last place) in a handful of entries. Differences of this magnitude are to be expected when comparing different floating-point implementations of the same function.
Pinging @pearu, who has recently been working on the accuracy of outputs for functions like arcsin and arcsinh: what do you think about this? Is a 1 ULP difference unexpected across different frameworks in cases like this?
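To make the 1 ULP claim concrete, here is a minimal sketch that measures the gap between the two sets of printed results in units of the float32 spacing. It uses only NumPy; the helper name `ulp_distance` is made up for illustration, and the arrays hold just the five entries that differ in the output posted above.

```python
import numpy as np

# The five entries that differ, copied from the outputs posted in this thread.
torch_vals = np.array([1.1150385, 2.699194, 2.954745, 2.4890096, 1.7526976],
                      dtype=np.float32)
jax_vals = np.array([1.1150384, 2.6991942, 2.9547448, 2.4890094, 1.7526975],
                    dtype=np.float32)

def ulp_distance(a, b):
    """Absolute difference between a and b in units of the float32 spacing at a.

    Hypothetical helper, not part of NumPy or JAX.
    """
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    # np.spacing(x) is the gap between x and the next representable float32.
    return np.abs(a - b) / np.spacing(np.abs(a))

ulps = ulp_distance(torch_vals, jax_vals)
print(ulps)  # each entry is 1.0 for these values

# NumPy also ships a ready-made check; it raises AssertionError if any pair
# of elements is more than `maxulp` representable values apart.
np.testing.assert_array_max_ulp(torch_vals, jax_vals, maxulp=1)
```

A ULP-based check like this may be a better fit for cross-framework comparisons than a fixed absolute tolerance such as 1e-7, because the float32 spacing is about 1.19e-07 for values in [1, 2) and about 2.38e-07 for values in [2, 4).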
Quoting the Wikipedia article on ULP: "Reputable libraries compute the basic transcendental functions to between 0.5 and about 1 ulp."
So I think an occasional 1 ULP difference between PyTorch and JAX functions is acceptable.
That said, I think we can do better. For instance, with the recent improvements to complex arcsinh, PyTorch and JAX are more likely to produce matching results:
>>> jax.numpy.arcsinh(jax.numpy.float32(9.572366714477539)) # error is 1 ULP
Array(2.954745, dtype=float32)
>>> jax.numpy.arcsinh(jax.numpy.complex64(9.572366714477539)) # error is 0 ULP
Array(2.9547448+0.j, dtype=complex64)
>>> numpy.arcsinh(numpy.float32(9.572366714477539))
2.9547448
>>> torch.arcsinh(torch.tensor([9.572366714477539], dtype=torch.float32)).numpy()
array([2.9547448], dtype=float32)
Finally, note that PyTorch arcsinh may not be a good reference for arcsinh accuracy tests. For instance, there are regions of the complex plane where PyTorch arcsinh results are incorrect or inaccurate; see the report here. Testing against a multiprecision math library such as mpmath is a better way to detect inaccuracies.
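As a rough sketch of that approach, the snippet below scores the two float32 results from the example above against a high-precision reference for the same input. The helper names `reference_asinh` and `ulp_error` are hypothetical. The float64 fallback for when mpmath is not installed is an assumption added only to keep the sketch runnable; it is adequate for judging float32 results but is not a substitute for a true multiprecision reference.

```python
import math
import numpy as np

try:
    import mpmath
    mpmath.mp.dps = 50  # 50 significant decimal digits, far beyond float32

    def reference_asinh(x):
        # Multiprecision reference value.
        return mpmath.asinh(mpmath.mpf(x))
except ImportError:
    def reference_asinh(x):
        # Fallback (assumption): float64 is still much more precise than float32.
        return math.asinh(x)

def ulp_error(x, computed):
    """Error of a float32 result, in float32 ULPs, against the reference."""
    x32 = np.float32(x)            # the input as the float32 kernels see it
    y32 = np.float32(computed)     # the float32 result being judged
    exact = reference_asinh(float(x32))
    # Size of one float32 ULP at the magnitude of the true result.
    ulp = float(np.spacing(np.float32(float(exact))))
    return float(abs(float(y32) - exact) / ulp)

x = 9.572366714477539
err_real_path = ulp_error(x, 2.954745)      # JAX float32 result quoted above
err_complex_path = ulp_error(x, 2.9547448)  # JAX complex64 / NumPy / PyTorch result

print(f"float32 path:   {err_real_path:.3f} ULP from the reference")
print(f"complex64 path: {err_complex_path:.3f} ULP from the reference")
```

An error below 0.5 ULP means the result is the correctly rounded float32 value. An error between 0.5 and 1 ULP means the neighbouring float32 value was returned, which is still within the range the Wikipedia quote describes.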
Description
Issue: JAX's arcsinh results are inconsistent with those of PyTorch, TensorFlow, and Keras.
Description: When using the same input data, the results returned by JAX's arcsinh function exhibit significant discrepancies compared to other frameworks. Below are the results from the various libraries:
Summary of Differences:
In JAX's results, the second pair [1.1150384, 2.838049] differs slightly from the corresponding pair in the other frameworks, [1.1150385, 2.838049]. The fifth pair in JAX's results, [2.4172907, 2.6991942], likewise differs slightly from the other frameworks, where it appears as [2.4172907, 2.699194]. Other discrepancies can also be observed; overall, JAX's results exhibit noticeable differences in precision compared to PyTorch and TensorFlow.
Recommendation:
Please investigate whether there is an issue with JAX's implementation, or whether the implementations differ across libraries, so that results are consistent among deep learning frameworks.
System info (python version, jaxlib version, accelerator, etc.)