jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Nightly builds on Mac ARM fail complex function numerical tests #24787

Open hawkinsp opened 2 weeks ago

hawkinsp commented 2 weeks ago

Description

On Mac ARM, the following functions are failing in the nightly jax/jaxlib build:

________ FunctionAccuracyTest.testSuccessOnComplexPlane_log1p_complex64 ________
[gw7] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
[tests/lax_test.py:4212](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4212): in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
[tests/lax_test.py:4438](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4438): in _testOnComplexPlaneWorker
    self.assertAllClose(
[jax/_src/test_util.py:1263](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1263): in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
[jax/_src/test_util.py:1228](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1228): in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
[jax/_src/public_test_util.py:128](https://cs.corp.google.com/piper///depot/google3/jax/_src/public_test_util.py?l=128): in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
[/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79](https://cs.corp.google.com/piper///depot/google3/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py?l=79): in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   log1p in q1.imag, is_cpu=True is_cuda=False,
E   jax.numpy.log1p((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j) [(0.7578582763671875+0j)], expected (4.517187335295603e-08+1.1754943508222875e-38j) [(0.7578582763671875+1.9721522630525295e-31j)]
E   jax.numpy.log1p((4.517187335295603e-08+5.204541205073606e-31j)) -> (4.517187335295603e-08+5.204541205073606e-31j) [(0.7578582763671875+8.731771197842018e-24j)], expected (4.517187335295603e-08+5.204540734875866e-31j) [(0.7578582763671875+8.731770408981113e-24j)]
E   jax.numpy.log1p((1.0202490673854366e-15+1.1754943508222875e-38j)) -> (1.0202490673854366e-15+0j) [(0.5743491649627686+0j)], expected (1.0202490673854366e-15+1.1754943508222875e-38j) [(0.5743491649627686+6.617444900424222e-24j)]
E   jax.numpy.log1p((4.517187335295603e-08+2.3043281796057088e-23j)) -> (4.517187335295603e-08+2.3043281796057088e-23j) [(0.7578582763671875+3.866021160413177e-16j)], expected (4.517187335295603e-08+2.3043280218335278e-23j) [(0.7578582763671875+3.866020895715381e-16j)]
E   jax.numpy.log1p((2.3043281796057088e-23+1.1754943508222875e-38j)) -> (2.3043281796057088e-23+0j) [(0.8705505728721619+0j)], expected (2.3043281796057088e-23+1.1754943508222875e-38j) [(0.8705505728721619+4.440892098500626e-16j)]
E   jax.numpy.log1p((5.204541205073606e-31+1.1754943508222875e-38j)) -> (5.204541205073606e-31+0j) [(0.6597539782524109+0j)], expected (5.204541205073606e-31+1.1754943508222875e-38j) [(0.6597539782524109+1.4901161193847656e-08j)]
E   jax.numpy.log1p((4.517187335295603e-08+4.517187335295603e-08j)) -> (4.517187335295603e-08+4.517187335295603e-08j) [(0.37892913818359375+0.37892913818359375j)], expected (4.517187335295603e-08+4.517186980024235e-08j) [(0.37892913818359375+0.37892910838127136j)]
E   jax.numpy.log1p((4.517187335295603e-08+2j)) -> (0.8047189712524414+1.1071487665176392j) [(0.4023594856262207+0.5535743832588196j)], expected (0.8047189712524414+1.1071486473083496j) [(0.4023594856262207+0.5535743236541748j)]
E   jax.numpy.log1p((1.1754943508222875e-38+1.1754943508222875e-38j)) -> (1.1754943508222875e-38+0j) [(0.5+0j)], expected (1.1754943508222875e-38+1.1754943508222875e-38j) [(0.5+0.5j)]
E   Mismatched elements: 1 / 121 (0.826%)
E   Max absolute difference among violations: 0.5
E   Max relative difference among violations: 1.
E    ACTUAL: array([[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
E    DESIRED: array([[5.000000e-01, 1.490116e-08, 4.440892e-16, 6.617445e-24,
E           1.972152e-31, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_square_complex64 ________
[gw7] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
[tests/lax_test.py:4212](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4212): in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
[tests/lax_test.py:4438](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4438): in _testOnComplexPlaneWorker
    self.assertAllClose(
[jax/_src/test_util.py:1263](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1263): in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
[jax/_src/test_util.py:1228](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1228): in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
[jax/_src/public_test_util.py:128](https://cs.corp.google.com/piper///depot/google3/jax/_src/public_test_util.py?l=128): in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
[/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79](https://cs.corp.google.com/piper///depot/google3/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py?l=79): in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   square in ninfj.real, is_cpu=True is_cuda=False,
E   jax.numpy.square((-1.735863837493982e+23-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E   jax.numpy.square((-3.4028234663852886e+38-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E   jax.numpy.square((-7.685595991398373e+30-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E   jax.numpy.square((1.735863837493982e+23-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E   jax.numpy.square((3.4028234663852886e+38-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E   jax.numpy.square((7.685595991398373e+30-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E   nan location mismatch:
E    ACTUAL: array([ nan,  nan,  nan,  nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,  nan,
E           nan,  nan,  nan], dtype=float32)
E    DESIRED: array([ nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E          -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E          -inf, -inf,  nan], dtype=float32)
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_arcsin_complex64 ________
[gw6] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
[tests/lax_test.py:4212](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4212): in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
[tests/lax_test.py:4438](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4438): in _testOnComplexPlaneWorker
    self.assertAllClose(
[jax/_src/test_util.py:1263](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1263): in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
[jax/_src/test_util.py:1228](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1228): in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
[jax/_src/public_test_util.py:128](https://cs.corp.google.com/piper///depot/google3/jax/_src/public_test_util.py?l=128): in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
[/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79](https://cs.corp.google.com/piper///depot/google3/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py?l=79): in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   arcsin in q1.real, is_cpu=True is_cuda=False,
E   jax.numpy.arcsin((1.1754943508222875e-38+4.517187335295603e-08j)) -> 4.517187335295603e-08j [0.7578582763671875j], expected (1.1754943508222875e-38+4.517187335295603e-08j) [(1.9721522630525295e-31+0.7578582763671875j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+1.0202490673854366e-15j)) -> 1.0202490673854366e-15j [0.5743491649627686j], expected (1.1754943508222875e-38+1.0202490673854366e-15j) [(6.617444900424222e-24+0.5743491649627686j)]
E   jax.numpy.arcsin((1.0202490673854366e-15+2j)) -> (4.562692259942242e-16+1.4436354637145996j) [(2.281346129971121e-16+0.7218177318572998j)], expected (4.562692789337834e-16+1.4436354637145996j) [(2.281346394668917e-16+0.7218177318572998j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+2.3043281796057088e-23j)) -> 2.3043281796057088e-23j [0.8705505728721619j], expected (1.1754943508222875e-38+2.3043281796057088e-23j) [(4.440892098500626e-16+0.8705505728721619j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+5.204541205073606e-31j)) -> 5.204541205073606e-31j [0.6597539782524109j], expected (1.1754943508222875e-38+5.204541205073606e-31j) [(1.4901161193847656e-08+0.6597539782524109j)]
E   jax.numpy.arcsin((4.517187335295603e-08+4.517187335295603e-08j)) -> (4.5171876905669706e-08+4.517187335295603e-08j) [(0.37892916798591614+0.37892913818359375j)], expected (4.517187335295603e-08+4.517187335295603e-08j) [(0.37892913818359375+0.37892913818359375j)]
E   jax.numpy.arcsin((4.517187335295603e-08+1.0202490673854366e-15j)) -> (4.5171876905669706e-08+1.0202490673854366e-15j) [(0.7578583359718323+1.7116938977324025e-08j)], expected (4.517187335295603e-08+1.0202490673854366e-15j) [(0.7578582763671875+1.7116938977324025e-08j)]
E   jax.numpy.arcsin((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.5171876905669706e-08+1.175494490952134e-38j) [(0.7578583359718323+1.9721524981513997e-31j)], expected (4.517187335295603e-08+1.1754943508222875e-38j) [(0.7578582763671875+1.9721522630525295e-31j)]
E   jax.numpy.arcsin((4.517187335295603e-08+2.3043281796057088e-23j)) -> (4.5171876905669706e-08+2.30432833737789e-23j) [(0.7578583359718323+3.866021425110973e-16j)], expected (4.517187335295603e-08+2.3043281796057088e-23j) [(0.7578582763671875+3.866021160413177e-16j)]
E   jax.numpy.arcsin((4.517187335295603e-08+5.204541205073606e-31j)) -> (4.5171876905669706e-08+5.204541675271346e-31j) [(0.7578583359718323+8.731771986702923e-24j)], expected (4.517187335295603e-08+5.204541205073606e-31j) [(0.7578582763671875+8.731771197842018e-24j)]
E   jax.numpy.arcsin((1.1754943508222875e-38+1.1754943508222875e-38j)) -> 1.1754943508222875e-38j [0.5j], expected (1.1754943508222875e-38+1.1754943508222875e-38j) [(0.5+0.5j)]
E   Mismatched elements: 1 / 121 (0.826%)
E   Max absolute difference among violations: 0.5
E   Max relative difference among violations: 1.
E    ACTUAL: array([[0.000000e+00, 6.597540e-01, 8.705506e-01, 5.743492e-01,
E           7.578583e-01, 3.926991e-01, 4.908739e-02, 2.454369e-02,
E           2.454369e-02, 1.227185e-02, 1.227185e-02],...
E    DESIRED: array([[5.000000e-01, 6.597540e-01, 8.705506e-01, 5.743492e-01,
E           7.578583e-01, 3.926991e-01, 4.908739e-02, 2.454369e-02,
E           2.454369e-02, 1.227185e-02, 1.227185e-02],...
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_arcsinh_complex64 _______
[gw6] darwin -- Python 3.10.13 /Users/kbuilder/.jax-pyenv/versions/3.10.13/bin/python
[tests/lax_test.py:4212](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4212): in testSuccessOnComplexPlane
    self._testOnComplexPlaneWorker(name, dtype, 'success')
[tests/lax_test.py:4438](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4438): in _testOnComplexPlaneWorker
    self.assertAllClose(
[jax/_src/test_util.py:1263](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1263): in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
[jax/_src/test_util.py:1228](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1228): in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
[jax/_src/public_test_util.py:128](https://cs.corp.google.com/piper///depot/google3/jax/_src/public_test_util.py?l=128): in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
[/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py:79](https://cs.corp.google.com/piper///depot/google3/Users/kbuilder/.jax-pyenv/versions/3.10.13/lib/python3.10/contextlib.py?l=79): in inner
    return func(*args, **kwds)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-06, atol=1e-06
E   arcsinh in q1.imag, is_cpu=True is_cuda=False,
E   jax.numpy.arcsinh((4.517187335295603e-08+1.1754943508222875e-38j)) -> (4.517187335295603e-08+0j) [(0.7578582763671875+0j)], expected (4.517187335295603e-08+1.1754943508222875e-38j) [(0.7578582763671875+1.9721522630525295e-31j)]
E   jax.numpy.arcsinh((1.0202490673854366e-15+1.1754943508222875e-38j)) -> (1.0202490673854366e-15+0j) [(0.5743491649627686+0j)], expected (1.0202490673854366e-15+1.1754943508222875e-38j) [(0.5743491649627686+6.617444900424222e-24j)]
E   jax.numpy.arcsinh((2+1.0202490673854366e-15j)) -> (1.4436354637145996+4.562692259942242e-16j) [(0.7218177318572998+2.281346129971121e-16j)], expected (1.4436354637145996+4.562692789337834e-16j) [(0.7218177318572998+2.281346394668917e-16j)]
E   jax.numpy.arcsinh((2.3043281796057088e-23+1.1754943508222875e-38j)) -> (2.3043281796057088e-23+0j) [(0.8705505728721619+0j)], expected (2.3043281796057088e-23+1.1754943508222875e-38j) [(0.8705505728721619+4.440892098500626e-16j)]
E   jax.numpy.arcsinh((5.204541205073606e-31+1.1754943508222875e-38j)) -> (5.204541205073606e-31+0j) [(0.6597539782524109+0j)], expected (5.204541205073606e-31+1.1754943508222875e-38j) [(0.6597539782524109+1.4901161193847656e-08j)]
E   jax.numpy.arcsinh((4.517187335295603e-08+4.517187335295603e-08j)) -> (4.517187335295603e-08+4.5171876905669706e-08j) [(0.37892913818359375+0.37892916798591614j)], expected (4.517187335295603e-08+4.517187335295603e-08j) [(0.37892913818359375+0.37892913818359375j)]
E   jax.numpy.arcsinh((1.0202490673854366e-15+4.517187335295603e-08j)) -> (1.0202490673854366e-15+4.5171876905669706e-08j) [(1.7116938977324025e-08+0.7578583359718323j)], expected (1.0202490673854366e-15+4.517187335295603e-08j) [(1.7116938977324025e-08+0.7578582763671875j)]
E   jax.numpy.arcsinh((1.1754943508222875e-38+4.517187335295603e-08j)) -> (1.175494490952134e-38+4.5171876905669706e-08j) [(1.9721524981513997e-31+0.7578583359718323j)], expected (1.1754943508222875e-38+4.517187335295603e-08j) [(1.9721522630525295e-31+0.7578582763671875j)]
E   jax.numpy.arcsinh((2.3043281796057088e-23+4.517187335295603e-08j)) -> (2.30432833737789e-23+4.5171876905669706e-08j) [(3.866021425110973e-16+0.7578583359718323j)], expected (2.3043281796057088e-23+4.517187335295603e-08j) [(3.866021160413177e-16+0.7578582763671875j)]
E   jax.numpy.arcsinh((5.204541205073606e-31+4.517187335295603e-08j)) -> (5.204541675271346e-31+4.5171876905669706e-08j) [(8.731771986702923e-24+0.7578583359718323j)], expected (5.204541205073606e-31+4.517187335295603e-08j) [(8.731771197842018e-24+0.7578582763671875j)]
E   jax.numpy.arcsinh((1.1754943508222875e-38+1.1754943508222875e-38j)) -> (1.1754943508222875e-38+0j) [(0.5+0j)], expected (1.1754943508222875e-38+1.1754943508222875e-38j) [(0.5+0.5j)]
E   Mismatched elements: 1 / 121 (0.826%)
E   Max absolute difference among violations: 0.5
E   Max relative difference among violations: 1.
E    ACTUAL: array([[0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
E    DESIRED: array([[5.000000e-01, 1.490116e-08, 4.440892e-16, 6.617445e-24,
E           1.972152e-31, 0.000000e+00, 0.000000e+00, 0.000000e+00,
E           0.000000e+00, 0.000000e+00, 0.000000e+00],...
=========================== short test summary info ============================
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_log1p_complex64
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_square_complex64
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_arcsin_complex64
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_arcsinh_complex64
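In the log1p, arcsin, and arcsinh reports, the single mismatched element (1 / 121) appears to be the input 1.1754943508222875e-38 + 1.1754943508222875e-38j, where one output component that should equal 1.1754943508222875e-38 comes back as 0. That constant is float32's smallest positive normal number (`np.finfo(np.float32).tiny`). One plausible but unconfirmed explanation is that an intermediate value drops into the subnormal range and is flushed to zero on these CPUs. The sketch below does not touch jaxlib; it only checks the constant and shows that, under NumPy's default gradual underflow, values just below it are subnormal rather than zero. `is_subnormal_f32` is a hypothetical helper written for this illustration.

```python
import numpy as np

# The input component that comes back zeroed in the log1p/arcsin/arcsinh reports.
FLT_MIN = np.float32(1.1754943508222875e-38)

def is_subnormal_f32(v):
    """True if v is a nonzero float32 whose magnitude is below the smallest normal."""
    v = np.float32(v)
    return bool(v != 0 and abs(v) < np.finfo(np.float32).tiny)

# The failing input component is exactly float32's smallest normal number (2**-126).
assert FLT_MIN == np.finfo(np.float32).tiny

# Halving it gives a subnormal under gradual underflow; a flush-to-zero (FTZ)
# mode would instead produce 0.0, which is what the failing outputs look like.
half = FLT_MIN / np.float32(2)
print(half, is_subnormal_f32(half))
```

If the FTZ hypothesis holds, the failures would come from an intermediate step rather than the final value, since the exact results for these inputs round to FLT_MIN itself, which is normal.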

@pearu can you PTAL?

System info (python version, jaxlib version, accelerator, etc.)

Mac ARM.

hawkinsp commented 2 weeks ago

There are also failures on Linux ARM:

=================================== FAILURES ===================================
_______ FunctionAccuracyTest.testSuccessOnComplexPlane_square_complex64 ________
[gw15] linux -- Python 3.10.15 /usr/bin/python3.10

self = <lax_test.FunctionAccuracyTest testMethod=testSuccessOnComplexPlane_square_complex64>
name = 'square', dtype = <class 'numpy.complex64'>

    @parameterized.named_parameters(
      dict(testcase_name=f"_{name}_{dtype.__name__}", name=name, dtype=dtype)
      for name, dtype in itertools.product(
          _functions_on_complex_plane,
          jtu.dtypes.supported([np.complex64, np.complex128]),
      ))
    @jtu.skip_on_devices("tpu")
    def testSuccessOnComplexPlane(self, name, dtype):
>     self._testOnComplexPlaneWorker(name, dtype, 'success')

[tests/lax_test.py:4212](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4212): 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
[tests/lax_test.py:4438](https://cs.corp.google.com/piper///depot/google3/tests/lax_test.py?l=4438): in _testOnComplexPlaneWorker
    self.assertAllClose(
[jax/_src/test_util.py:1263](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1263): in assertAllClose
    self.assertArraysAllClose(x, y, check_dtypes=False, atol=atol, rtol=rtol,
[jax/_src/test_util.py:1228](https://cs.corp.google.com/piper///depot/google3/jax/_src/test_util.py?l=1228): in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
[jax/_src/public_test_util.py:128](https://cs.corp.google.com/piper///depot/google3/jax/_src/public_test_util.py?l=128): in _assert_numpy_allclose
    np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0xfffce5419480>, array([ nan,  nan,  nan,  nan, -inf, -inf, -inf, -inf,...inf,
       -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
       -inf, -inf,  nan], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': 'square in ninfj.real, is_cpu=True is_cuda=False,\njax.numpy.square((-1.735863837493982..., expected (-inf-infj) [(-inf-infj)]', 'header': 'Not equal to tolerance rtol=1e-06, atol=1e-06', 'strict': False, ...}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=1e-06
E           square in ninfj.real, is_cpu=True is_cuda=False,
E           jax.numpy.square((-1.735863837493982e+23-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E           jax.numpy.square((-3.4028234663852886e+38-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E           jax.numpy.square((-7.685595991398373e+30-infj)) -> (nan+infj) [(nan+infj)], expected (-inf+infj) [(-inf+infj)]
E           jax.numpy.square((1.735863837493982e+23-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E           jax.numpy.square((3.4028234663852886e+38-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E           jax.numpy.square((7.685595991398373e+30-infj)) -> (nan-infj) [(nan-infj)], expected (-inf-infj) [(-inf-infj)]
E           nan location mismatch:
E            ACTUAL: array([ nan,  nan,  nan,  nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E                  -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,  nan,
E                   nan,  nan,  nan], dtype=float32)
E            DESIRED: array([ nan, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E                  -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
E                  -inf, -inf,  nan], dtype=float32)

[/usr/lib/python3.10/contextlib.py:79](https://cs.corp.google.com/piper///depot/google3/usr/lib/python3.10/contextlib.py?l=79): AssertionError
=========================== short test summary info ============================
FAILED tests/lax_test.py::FunctionAccuracyTest::testSuccessOnComplexPlane_square_complex64
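In both the Mac ARM and Linux ARM square failures, every mismatched input has imaginary part -inf and a real part whose magnitude is at least 1.735863837493982e+23, so x*x overflows float32 (maximum about 3.4e38). The reported results are consistent with the real part being computed as x*x - y*y, which gives inf - inf = nan, whereas the mathematically expected real part is -inf. The sketch below is a hypothesis check in NumPy, not the actual jaxlib code path; `naive_square` and `factored_square` are hypothetical helpers. It compares that textbook formula with the factored form (x - y) * (x + y), which never forms x*x.

```python
import numpy as np

def naive_square(x, y):
    """Hypothetical textbook formula (x + iy)^2 = (x*x - y*y) + i*(2*x*y), in float32."""
    x, y = np.float32(x), np.float32(y)
    with np.errstate(over="ignore", invalid="ignore"):
        return x * x - y * y, np.float32(2) * x * y

def factored_square(x, y):
    """Same imaginary part, but the real part as (x - y) * (x + y) to avoid forming x*x."""
    x, y = np.float32(x), np.float32(y)
    with np.errstate(over="ignore", invalid="ignore"):
        return (x - y) * (x + y), np.float32(2) * x * y

# Real parts taken from the failing log lines; the imaginary part is -inf in every case.
for x in [-1.735863837493982e+23, -3.4028234663852886e+38, 7.685595991398373e+30]:
    print(x, naive_square(x, -np.inf), factored_square(x, -np.inf))
```

With y = -inf and finite x, the naive form yields (nan, +/-inf), matching the ACTUAL values in the logs, while the factored form yields (-inf, +/-inf), matching the expected values. For x = +/-inf the factored real part is still nan, since inf - inf is undefined either way.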
pearu commented 1 week ago

Update: