google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

LinearSolveTest.test_solve_sparse fails with jax 0.4.26 #960

Closed. GaetanLepage closed this issue 2 months ago.

GaetanLepage commented 2 months ago

Context: updating jax in nixpkgs: https://github.com/NixOS/nixpkgs/pull/291705#issuecomment-2095894365

One of the optax tests fails when run with the latest jax (0.4.26):

============================= test session starts ==============================
platform linux -- Python 3.11.9, pytest-8.1.1, pluggy-1.4.0
rootdir: /build/source
plugins: xdist-3.5.0
48 workers [561 items]
...s.................................................................... [ 12%]
........................................................................ [ 25%]
.............................................s......s............s...... [ 38%]
............s........s.................................................. [ 51%]
.................F...................................................... [ 64%]
........................................................................ [ 77%]
........................................................................ [ 89%]
.........................................................                [100%]
=================================== FAILURES ===================================
______________________ LinearSolveTest.test_solve_sparse _______________________
[gw24] linux -- Python 3.11.9 /nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/bin/python3.11

self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse>

    def test_solve_sparse(self):
      rng = onp.random.RandomState(0)

      # Matrix case.
      A = rng.randn(5, 5)
      b = rng.randn(5)

      def matvec(x):
        return jnp.dot(A, x)

      x = linear_solve.solve_lu(matvec, b)
      x2 = linear_solve.solve_normal_cg(matvec, b)
      x3 = linear_solve.solve_gmres(matvec, b)
      x4 = linear_solve.solve_bicgstab(matvec, b)
      x5 = linear_solve.solve_iterative_refinement(matvec, b)
      x6 = linear_solve.solve_qr(matvec, b)

      self.assertArraysAllClose(x, x2, atol=1e-4)
      self.assertArraysAllClose(x, x3, atol=1e-4)
>     self.assertArraysAllClose(x, x4, atol=1e-4)

tests/linear_solve_test.py:133: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0x7ffcd857af20>, array([-6.9443436, -1.9871655,  7.7470713,  7.654949 ,...87526],
      dtype=float32), array([-6.9444494, -1.9872105,  7.7471952,  7.655079 , -7.0388584],
      dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=0.0001
E           
E           Mismatched elements: 2 / 5 (40%)
E           Max absolute difference: 0.0001297
E           Max relative difference: 2.267556e-05
E            x: array([-6.944344, -1.987165,  7.747071,  7.654949, -7.038753],
E                 dtype=float32)
E            y: array([-6.944449, -1.987211,  7.747195,  7.655079, -7.038858],
E                 dtype=float32)

/nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/lib/python3.11/contextlib.py:81: AssertionError

Any idea?
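The assertion in the log fails only narrowly. `numpy.testing.assert_allclose(actual, desired, rtol, atol)` accepts an element when `|actual - desired| <= atol + rtol * |desired|`, so with `atol=1e-4` and `rtol=1e-6` the allowed gap for values near 7.7 is about 1.08e-4. The sketch below recomputes that check in plain NumPy from the rounded values printed in the failure message. It is an illustration, not part of the jaxopt test suite, and because the log shows only 6 decimals the differences are approximate. It shows two elements exceeding the bound by roughly 15 to 20 percent.

```python
import numpy as np

# Values copied from the pytest output above (printed to 6 decimals,
# so the recomputed differences are approximate).
x = np.array([-6.944344, -1.987165, 7.747071, 7.654949, -7.038753])  # solve_lu result
y = np.array([-6.944449, -1.987211, 7.747195, 7.655079, -7.038858])  # solve_bicgstab result

atol, rtol = 1e-4, 1e-6  # tolerances reported in the assertion header

# assert_allclose(x, y) treats y as the "desired" array and checks,
# element by element: |x - y| <= atol + rtol * |y|
abs_diff = np.abs(x - y)
allowed = atol + rtol * np.abs(y)
mismatched = abs_diff > allowed

print(mismatched)             # [False False  True  True False]
print(int(mismatched.sum()))  # 2, matching "Mismatched elements: 2 / 5"
print(abs_diff.max())         # roughly 1.3e-4, just above atol
```

The two failing elements (indices 2 and 3) are the ones whose absolute difference exceeds about 1.08e-4; the other three fall just inside the bound.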

GaetanLepage commented 2 months ago

Wrong repo. This issue is actually happening in jaxopt. Sorry for the inconvenience.