Open issue · opened by hawkinsp 2 weeks ago
There are also failures on Linux ARM:
=================================== FAILURES ===================================
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_square_complex64 ________
[gw15] linux -- Python 3.10.15 /usr/bin/python3.10
self = <lax_test.FunctionAccuracyTest testMethod=testSuccessOnComplexPlane_square_complex64>
name = 'square', dtype = <class 'numpy.complex64'>
    @parameterized.named_parameters(
        dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
        for name, dtype in itertools.product(
            _functions_on_complex_plane,
            jtu.dtypes.supported([np.complex64, np.complex128]),
        ))
    @jtu.skip_on_devices("tpu")
    def testSuccessOnComplexPlane(self, name, dtype):
>       self._testOnComplexPlaneWorker(name, dtype, 'success')
tests/lax_test.py:4212:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/lax_test.py:4438: in _testOnComplexPlaneWorker
    self.assertAllClose(
jax/_src/test_util.py:1263: in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
jax/_src/test_util.py:1228: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jax/_src/public_test_util.py:128: in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0xfffce5419480>, array([ nan, nan, nan, nan, -inf, -inf, -inf, -inf,...inf,
-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
-inf, -inf, nan], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': 'square in ninfj.real, is_cpu=True is_cuda=False,\njax.numpy.square((-1.735863837493982..., expected (-inf-infj) [(-inf-infj)]', 'header': 'Not equal to tolerance rtol=1e-06, atol=1e-06', 'strict': False, ...}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=1e-06, atol=1e-06
E square in ninfj.real, is_cpu=True is_cuda=False,
E jax.numpy.square((-1.735863837493982e+23-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E jax.numpy.square((-3.4028234663852886e+38-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E jax.numpy.square((-7.685595991398373e+30-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E jax.numpy.square((1.735863837493982e+23-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E jax.numpy.square((3.4028234663852886e+38-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E jax.numpy.square((7.685595991398373e+30-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E nan location mismatch:
E ACTUAL: array([ nan, nan, nan, nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, nan,
E nan, nan, nan], dtype=float32)
E DESIRED: array([ nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E -inf, -inf, nan], dtype=float32)
/usr/lib/python3.10/contextlib.py:79: AssertionError
=========================== short test summary info ============================
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_square_complex64
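All six failing square samples have a finite real part x whose magnitude exceeds sqrt(float32 max), about 1.8e19, and an imaginary part of -inf. A nan real part is exactly what the textbook formula Re(z**2) = x*x - y*y produces there (inf - inf), while the algebraically equivalent (x - y)*(x + y) gives the expected -inf. The sketch below illustrates only this arithmetic with NumPy float32 scalars; `square_naive` and `square_factored` are hypothetical helpers, not how `jnp.square` or XLA actually evaluate complex squares, and the sketch does not by itself explain why only ARM is affected.

```python
import numpy as np


def square_naive(x, y):
    """Return (re, im) of (x + iy)**2 using re = x*x - y*y."""
    with np.errstate(over="ignore", invalid="ignore"):
        # x*x overflows to inf once |x| > sqrt(float32 max) ~ 1.8e19;
        # with y = -inf this becomes inf - inf, which is nan.
        return x * x - y * y, 2 * x * y


def square_factored(x, y):
    """Return (re, im) of (x + iy)**2 using re = (x - y) * (x + y)."""
    with np.errstate(over="ignore", invalid="ignore"):
        # For finite x and y = -inf: (x - y) = inf and (x + y) = -inf, so re = -inf.
        return (x - y) * (x + y), 2 * x * y


# Real parts of the six failing samples in the log; each imaginary part is -inf.
failing_reals = [-1.735863837493982e23, -3.4028234663852886e38, -7.685595991398373e30,
                 1.735863837493982e23, 3.4028234663852886e38, 7.685595991398373e30]

for r in failing_reals:
    x, y = np.float32(r), np.float32(-np.inf)
    print(r, square_naive(x, y), square_factored(x, y))
```

For every sample the naive real part is nan and the factored real part is -inf, while the imaginary part (2*x*y) is +inf for negative x and -inf for positive x, matching the expected values in the log.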
Update:
square
log1p, arcsin, and arcsinh test failures on Linux ARM. In all cases, there is only one failing sample:
jax.numpy.log1p((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j)
jax.numpy.arcsin((1.1754943508222875e-38+4.517187335295603e-08j)) -> 4.517187335295603e-08j
jax.numpy.arcsinh((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j)
(All other samples are within the allowed error tolerance.) The expected results are:
>>> jax.numpy.log1p((4.517187335295603e-08+1.1754943508222875e-38j))
Array(4.5171873e-08+1.1754944e-38j, dtype=complex64, weak_type=True)
>>> jax.numpy.arcsin((1.1754943508222875e-38+4.517187335295603e-08j))
Array(1.1754944e-38+4.5171873e-08j, dtype=complex64, weak_type=True)
>>> jax.numpy.arcsinh((4.517187335295603e-08+1.1754943508222875e-38j))
Array(4.5171873e-08+1.1754944e-38j, dtype=complex64, weak_type=True)
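In each failing sample, the small component of the input is 1.1754943508222875e-38, which is exactly the smallest normal float32 (2**-126). The sketch below uses textbook identities (exact for log1p, first order in the small component for arcsinh, and arcsin(z) = -i*arcsinh(i*z) for arcsin), not JAX's actual implementations, to show that the exact small component of each result lies a hair below that value and only rounds up to 1.1754944e-38 when cast to float32. A float32 evaluation whose intermediate lands in the subnormal range would be consistent with the +0j / 0-real results above if subnormals are flushed to zero. `straddles_normal_boundary` is a hypothetical helper.

```python
import math

import numpy as np

TINY = float(np.finfo(np.float32).tiny)  # smallest normal float32, 2**-126

# Failing sample from the report; Y equals TINY exactly.
X = 4.517187335295603e-08
Y = 1.1754943508222875e-38


def straddles_normal_boundary(v: float) -> bool:
    """True if v (a float64) is just below the smallest normal float32
    but still rounds up to it when cast to float32."""
    return bool(v < TINY and np.float32(v) == np.float32(TINY))


# Exact identity: Im(log1p(X + iY)) = atan2(Y, 1 + X), evaluated in float64.
log1p_imag = math.atan2(Y, 1.0 + X)

# First-order approximation for tiny Y: Im(arcsinh(X + iY)) ~ Y / sqrt(1 + X**2).
# Via arcsin(z) = -i*arcsinh(i*z), this also equals Re(arcsin(Y + iX)).
arcsinh_imag = Y / math.sqrt(1.0 + X * X)

for name, v in [("log1p imag", log1p_imag),
                ("arcsinh imag / arcsin real", arcsinh_imag)]:
    print(f"{name}: {v!r}, straddles boundary: {straddles_normal_boundary(v)}")
```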
It is possible that the flush-to-zero (FTZ) settings differ between Mac ARM and Linux ARM. Can this be verified? (See also the feature request in https://github.com/jax-ml/jax/issues/24280.)
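One way to check the FTZ hypothesis directly is to push a product that must be subnormal through the arithmetic under test and see whether it comes back as zero. The probe below is a minimal sketch: `flushes_subnormal_results` is a hypothetical helper, and the suggestion to pass `jax.jit(lambda a, b: a * b)` assumes the jitted multiply runs on the CPU backend under that process's floating-point control settings. The outcome on Mac ARM and Linux ARM has not been checked here.

```python
import numpy as np

TINY32 = np.finfo(np.float32).tiny  # smallest normal float32, 2**-126


def flushes_subnormal_results(multiply) -> bool:
    """Return True if multiply(a, b) flushes a subnormal float32 product to zero.

    `multiply` is any callable taking two float32 arrays, e.g. np.multiply or
    (as an assumption) jax.jit(lambda a, b: a * b). The operands are passed as
    runtime arrays so a compiler cannot constant-fold the product.
    """
    a = np.array([TINY32], dtype=np.float32)
    b = np.array([0.5], dtype=np.float32)  # TINY32 * 0.5 == 2**-127, a subnormal
    out = np.asarray(multiply(a, b))
    return bool(out[0] == 0)


print("NumPy flushes subnormals:", flushes_subnormal_results(np.multiply))
```

NumPy does not enable flush-to-zero itself, so on a typical machine the NumPy probe returns False. If a jitted JAX multiply returned True on one platform and False on the other, that would support the FTZ explanation.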
Description
On Mac ARM, the following functions are failing in the nightly jax/jaxlib build:
@pearu can you PTAL?
System info (python version, jaxlib version, accelerator, etc.)
Mac ARM.