GaetanLepage closed this issue 2 months ago.
Context: updating jax in nixpkgs: https://github.com/NixOS/nixpkgs/pull/291705#issuecomment-2095894365
One of the optax tests fails when run with the latest jax (0.4.26):
============================= test session starts ==============================
platform linux -- Python 3.11.9, pytest-8.1.1, pluggy-1.4.0
rootdir: /build/source
plugins: xdist-3.5.0
48 workers [561 items]
...s.................................................................... [ 12%]
........................................................................ [ 25%]
.............................................s......s............s...... [ 38%]
............s........s.................................................. [ 51%]
.................F...................................................... [ 64%]
........................................................................ [ 77%]
........................................................................ [ 89%]
......................................................... [100%]
=================================== FAILURES ===================================
______________________ LinearSolveTest.test_solve_sparse _______________________
[gw24] linux -- Python 3.11.9 /nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/bin/python3.11

self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse>

    def test_solve_sparse(self):
      rng = onp.random.RandomState(0)

      # Matrix case.
      A = rng.randn(5, 5)
      b = rng.randn(5)

      def matvec(x):
        return jnp.dot(A, x)

      x = linear_solve.solve_lu(matvec, b)
      x2 = linear_solve.solve_normal_cg(matvec, b)
      x3 = linear_solve.solve_gmres(matvec, b)
      x4 = linear_solve.solve_bicgstab(matvec, b)
      x5 = linear_solve.solve_iterative_refinement(matvec, b)
      x6 = linear_solve.solve_qr(matvec, b)

      self.assertArraysAllClose(x, x2, atol=1e-4)
      self.assertArraysAllClose(x, x3, atol=1e-4)
>     self.assertArraysAllClose(x, x4, atol=1e-4)

tests/linear_solve_test.py:133:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0x7ffcd857af20>, array([-6.9443436, -1.9871655, 7.7470713, 7.654949 ,...87526], dtype=float32), array([-6.9444494, -1.9872105, 7.7471952, 7.655079 , -7.0388584], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Not equal to tolerance rtol=1e-06, atol=0.0001
E
E           Mismatched elements: 2 / 5 (40%)
E           Max absolute difference: 0.0001297
E           Max relative difference: 2.267556e-05
E            x: array([-6.944344, -1.987165,  7.747071,  7.654949, -7.038753],
E                 dtype=float32)
E            y: array([-6.944449, -1.987211,  7.747195,  7.655079, -7.038858],
E                 dtype=float32)

/nix/store/lpi16513bai8kg2bd841745vzk72475x-python3-3.11.9/lib/python3.11/contextlib.py:81: AssertionError
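To see how close this failure is to passing, the element-wise check can be redone by hand. This is a minimal sketch that assumes the criterion documented for `numpy.testing.assert_allclose(actual, desired)`, namely |actual - desired| <= atol + rtol * |desired|, with `desired` being the second (`solve_bicgstab`) array. It uses the six-decimal values printed in the assertion message, so it approximates the float32 arrays that were actually compared.

```python
import numpy as np

# Values copied from the "x:" and "y:" lines of the assertion message
# (rounded to 6 decimals as printed, so only an approximation of the
# float32 arrays that were actually compared).
x = np.array([-6.944344, -1.987165, 7.747071, 7.654949, -7.038753])  # solve_lu
y = np.array([-6.944449, -1.987211, 7.747195, 7.655079, -7.038858])  # solve_bicgstab

atol, rtol = 1e-4, 1e-6  # tolerances reported in the failure header

# assert_allclose(actual, desired) requires, element-wise:
#   |actual - desired| <= atol + rtol * |desired|
abs_diff = np.abs(x - y)
allowed = atol + rtol * np.abs(y)
mismatched = abs_diff > allowed

print("abs diff:", abs_diff)
print("allowed: ", allowed)
print("mismatched indices:", np.flatnonzero(mismatched))

# np.isclose applies the same (asymmetric) criterion.
assert np.array_equal(~mismatched, np.isclose(x, y, rtol=rtol, atol=atol))
```

With these values only indices 2 and 3 should be flagged, consistent with "Mismatched elements: 2 / 5 (40%)": their gaps (about 1.2e-4 and 1.3e-4) exceed the allowed ~1.08e-4, while the other three elements stay just within tolerance.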
Any idea?
Wrong repo. This issue is actually happening in jaxopt. Sorry for the inconvenience.