jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Arcsinh Error Report #24276

Open · LilyDong0127 opened this issue 1 month ago

LilyDong0127 commented 1 month ago

Description

Issue: JAX's arcsinh results are inconsistent with those of PyTorch, TensorFlow, and Keras.

Description: When using the same input data, the results returned by JAX's arcsinh function exhibit significant discrepancies compared to other frameworks. Below is the comparison script, followed by the output from each library:

import torch
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
from keras.layers import Lambda

# Input data
input_data = np.array([
    [
        [5.834500789642334, 0.5778982043266296],
        [1.360891342163086, 8.511932373046875],
        [-8.579278945922852, -8.257413864135742],
        [-9.59563159942627, 6.6523966789245605],
        [5.563135147094727, 7.400242805480957],
        [9.572366714477539, 5.983171463012695],
        [-0.7704126238822937, 5.610583305358887],
        [-7.634511470794678, 2.7984204292297363],
        [-7.132934093475342, 8.893378257751465]
    ]
])

# PyTorch arcsinh operation
def torch_arcsinh(x):
    return torch.asinh(torch.tensor(x, dtype=torch.float32))

# TensorFlow arcsinh operation
def tf_arcsinh(x):
    return tf.asinh(tf.convert_to_tensor(x, dtype=tf.float32))

# Keras arcsinh operation
def keras_arcsinh(x):
    return Lambda(lambda x: tf.math.asinh(x))(tf.convert_to_tensor(x, dtype=tf.float32))

# JAX arcsinh operation
def jax_arcsinh(x):
    return jnp.arcsinh(jnp.array(x))

# Calculate results
pytorch_result = torch_arcsinh(input_data)
tensorflow_result = tf_arcsinh(input_data)
keras_result = keras_arcsinh(input_data).numpy()  # Convert Keras result to numpy
jax_result = jax_arcsinh(input_data)

# Print results
print(f"PyTorch arcsinh result: {pytorch_result.detach().numpy()}")  # Detach to convert to numpy
print(f"TensorFlow arcsinh result: {tensorflow_result.numpy()}")
print(f"Keras arcsinh result: {keras_result}")
print(f"JAX arcsinh result: {jax_result}")

# Compare results
tolerance = 1e-7  # Set a tolerance for comparison
results = {
    "PyTorch": pytorch_result.detach().numpy(),
    "TensorFlow": tensorflow_result.numpy(),
    "Keras": keras_result,
    "JAX": jax_result
}

for name, result in results.items():
    diff = np.abs(results["PyTorch"] - result)
    print(f"Difference with {name}: {diff}")

# Check for passing criteria
passed = all(np.allclose(results["PyTorch"], result, atol=tolerance) for name, result in results.items() if name != "PyTorch")
print(f"Tests passed: {passed}")
PyTorch arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150385  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.699194 ]
  [ 2.954745   2.4890096]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526976]
  [-2.6627476  2.8816001]]]
TensorFlow arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150385  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.699194 ]
  [ 2.954745   2.4890096]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526976]
  [-2.6627476  2.8816001]]]
Keras arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150385  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.699194 ]
  [ 2.954745   2.4890096]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526976]
  [-2.6627476  2.8816001]]]
JAX arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150384  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.6991942]
  [ 2.9547448  2.4890094]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526975]
  [-2.6627476  2.8816001]]]

Summary of Differences: In JAX's results, the entry [1.1150384, 2.838049] differs slightly from the corresponding entry [1.1150385, 2.838049] in the other frameworks. Likewise, JAX's entry [2.4172907, 2.6991942] differs slightly from the other frameworks' [2.4172907, 2.699194], and a few other entries show the same pattern. Overall, JAX's results exhibit noticeable precision differences compared to PyTorch and TensorFlow.

Recommendation: Please investigate whether there is an issue with JAX's implementation, or whether the implementations are inconsistent across libraries, to ensure result consistency among deep learning frameworks.
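For scale, the differing digits sit in the last place of the float32 printout; the spacing between adjacent float32 values at these magnitudes can be checked with NumPy (a quick sketch using the values above):

>>> import numpy as np
>>> np.spacing(np.float32(1.1150385))  # float32 resolution near the first differing entry
1.1920929e-07
>>> np.spacing(np.float32(2.699194))   # float32 resolution near the second differing entry
2.3841858e-07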

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:38:46) [MSC v.1929 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Lily的电脑', release='10', version='10.0.22631', machine='AMD64')
jakevdp commented 1 month ago

I suspect this is the same thing seen in #24275: JAX is computing in 32-bit, and you're comparing to platforms that compute in 64-bit.
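A minimal way to test that hypothesis (a sketch, assuming the default configuration in which JAX downcasts inputs to float32 unless jax_enable_x64 is set):

import jax
jax.config.update("jax_enable_x64", True)  # switch JAX to 64-bit; must run before any computation

import jax.numpy as jnp
import numpy as np

x = np.float64(9.572366714477539)  # one of the inputs from the report
print(jnp.arcsinh(x))  # now computed in float64
print(np.arcsinh(x))   # NumPy float64 reference; the two should agree far beyond float32 precision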

LilyDong0127 commented 1 month ago

> I suspect this is the same thing seen in #24275: JAX is computing in 32-bit, and you're comparing to platforms that compute in 64-bit.

But please look at my latest code: I use float32, yet JAX's final result still clearly differs from the results of the other deep learning libraries, while the other three libraries agree with each other exactly.


import torch
import tensorflow as tf
import jax.numpy as jnp
import numpy as np
from keras.layers import Lambda

# Input data
input_data = np.array([
    [
        [5.834500789642334, 0.5778982043266296],
        [1.360891342163086, 8.511932373046875],
        [-8.579278945922852, -8.257413864135742],
        [-9.59563159942627, 6.6523966789245605],
        [5.563135147094727, 7.400242805480957],
        [9.572366714477539, 5.983171463012695],
        [-0.7704126238822937, 5.610583305358887],
        [-7.634511470794678, 2.7984204292297363],
        [-7.132934093475342, 8.893378257751465]
    ]
], dtype=np.float32)  # Ensure input data is float32

# PyTorch arcsinh operation
def torch_arcsinh(x):
    return torch.asinh(torch.tensor(x, dtype=torch.float32))

# TensorFlow arcsinh operation
def tf_arcsinh(x):
    return tf.asinh(tf.convert_to_tensor(x, dtype=tf.float32))

# Keras arcsinh operation
def keras_arcsinh(x):
    return Lambda(lambda x: tf.math.asinh(x))(tf.convert_to_tensor(x, dtype=tf.float32))

# JAX arcsinh operation
def jax_arcsinh(x):
    return jnp.arcsinh(jnp.array(x, dtype=np.float32))  # Set dtype to float32

# Calculate results
pytorch_result = torch_arcsinh(input_data)
tensorflow_result = tf_arcsinh(input_data)
keras_result = keras_arcsinh(input_data).numpy()  # Convert Keras result to numpy
jax_result = jax_arcsinh(input_data)

# Print results
print(f"PyTorch arcsinh result: {pytorch_result.detach().numpy()}")  # Detach to convert to numpy
print(f"TensorFlow arcsinh result: {tensorflow_result.numpy()}")
print(f"Keras arcsinh result: {keras_result}")
print(f"JAX arcsinh result: {jax_result}")

# Compare results
tolerance = 1e-7  # Set a tolerance for comparison
results = {
    "PyTorch": pytorch_result.detach().numpy(),
    "TensorFlow": tensorflow_result.numpy(),
    "Keras": keras_result,
    "JAX": jax_result
}

for name, result in results.items():
    diff = np.abs(results["PyTorch"] - result)
    print(f"Difference with {name}: {diff}")

# Check for passing criteria
passed = all(np.allclose(results["PyTorch"], result, atol=tolerance) for name, result in results.items() if name != "PyTorch")
print(f"Tests passed: {passed}")

PyTorch arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150385  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.699194 ]
  [ 2.954745   2.4890096]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526976]
  [-2.6627476  2.8816001]]]
TensorFlow arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150385  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.699194 ]
  [ 2.954745   2.4890096]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526976]
  [-2.6627476  2.8816001]]]
Keras arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150385  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.699194 ]
  [ 2.954745   2.4890096]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526976]
  [-2.6627476  2.8816001]]]
JAX arcsinh result: [[[ 2.4642003  0.5497806]
  [ 1.1150384  2.838049 ]
  [-2.8458765 -2.8079052]
  [-2.9571593  2.5937262]
  [ 2.4172907  2.6991942]
  [ 2.9547448  2.4890094]
  [-0.7093974  2.4256508]
  [-2.730088   1.7526975]
  [-2.6627476  2.8816001]]]
Difference with PyTorch: [[[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]
Difference with TensorFlow: [[[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]
Difference with Keras: [[[0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]
  [0. 0.]]]
Difference with JAX: [[[0.0000000e+00 0.0000000e+00]
  [1.1920929e-07 0.0000000e+00]
  [0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 2.3841858e-07]
  [2.3841858e-07 2.3841858e-07]
  [0.0000000e+00 0.0000000e+00]
  [0.0000000e+00 1.1920929e-07]
  [0.0000000e+00 0.0000000e+00]]]

jakevdp commented 1 month ago

It appears JAX is differing from PyTorch by 1 ULP. Differences of this magnitude are to be expected when comparing different floating-point implementations of a particular function.
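For concreteness, a minimal NumPy check that the two reported results for arcsinh(9.572366714477539) are adjacent float32 values, i.e. exactly 1 ULP apart:

>>> import numpy as np
>>> a = np.float32(2.954745)   # JAX result
>>> b = np.float32(2.9547448)  # PyTorch/TensorFlow/Keras result
>>> np.nextafter(a, np.float32(-np.inf)) == b  # one representable step down from a lands on b
True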

jakevdp commented 1 month ago

Pinging @pearu, who has recently been working on the accuracy of functions like arcsin and arcsinh – what do you think about this? Is a 1 ULP difference across different frameworks unexpected in cases like this?

pearu commented 1 month ago

Quoting the Wikipedia article on ULP: "Reputable libraries compute the basic transcendental functions to between 0.5 and about 1 ulp."

So, I think an occasional 1 ULP difference between PyTorch and JAX functions is acceptable.

That said, I think we can do better. For instance, with the recent improvements to complex arcsinh, PyTorch and JAX results are more likely to agree. For example:

>>> jax.numpy.arcsinh(jax.numpy.float32(9.572366714477539))  # error is 1 ULP
Array(2.954745, dtype=float32)
>>> jax.numpy.arcsinh(jax.numpy.complex64(9.572366714477539))  # error is 0 ULP
Array(2.9547448+0.j, dtype=complex64)
>>> numpy.arcsinh(numpy.float32(9.572366714477539))
2.9547448
>>> torch.arcsinh(torch.tensor([9.572366714477539], dtype=torch.float32)).numpy()
array([2.9547448], dtype=float32)

Finally, note that PyTorch's arcsinh may not be a good reference for arcsinh accuracy tests: there exist regions of the complex plane where PyTorch's arcsinh results are incorrect or inaccurate, see the report here. Using a multiprecision math library such as mpmath is a better approach for testing functions for inaccuracies.
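As a sketch of that approach (assuming mpmath is installed, with mpmath.asinh as the high-precision reference):

>>> import numpy as np
>>> import mpmath
>>> mpmath.mp.prec = 100  # ~100 bits of working precision, far beyond float32
>>> x = np.float32(9.572366714477539)
>>> np.float32(float(mpmath.asinh(mpmath.mpf(float(x)))))  # reference rounded to float32
2.9547448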