google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0
922 stars 64 forks source link

unit test failures on aarch64 linux with scipy 1.12 #577

Status: Open. Opened by ghost 8 months ago.

ghost commented 8 months ago

jaxopt 0.8.3, with the patch from https://github.com/google/jaxopt/pull/574 applied, fails the following unit tests on aarch64 Linux after updating to scipy 1.12:

=================================== FAILURES ===================================
__________________ LbfgsTest.test_binary_logit_log_likelihood __________________
[gw45] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11
self = <lbfgs_test.LbfgsTest testMethod=test_binary_logit_log_likelihood>
    def test_binary_logit_log_likelihood(self):
      # See issue #409
      rng = jax.random.PRNGKey(42)
      N = 1000
      beta = jnp.array([[0.5,0.5]]).T
      income = jax.random.normal(rng, shape=(N,1))
      x = jnp.hstack([jnp.ones((N,1)), income])

      def simulate_binary_logit(x, beta):
        beta = beta.reshape(-1,1)
        N = x.shape[0]
        J = beta.shape[0]

        epsilon = jax.random.gumbel(rng,shape =(N,J))
        Beta_augmented = jnp.hstack([beta, jnp.zeros_like(beta)])
        utility = x @ Beta_augmented + epsilon

        choice_idx = onp.argmax(utility, axis=1)
        return (choice_idx).reshape(-1,1)

      y = simulate_binary_logit(x, beta)
      y = jnp.ravel(y)

      # numpy version
      def binary_logit_log_likelihood(beta, y,x):
        lambda_xb = onp.exp(x@beta) / (1 + onp.exp(x@beta))
        ll_i = y * onp.log(lambda_xb) + (1-y) * onp.log(1-lambda_xb)
        ll = -onp.sum(ll_i)
        return ll

      # jax version
      def binary_logit_log_likelihood_jax(beta, y, x):
        lambda_xb = jnp.exp(x@beta) / (1 + jnp.exp(x@beta))
        ll_i = y * jnp.log(lambda_xb) + (1-y) * jnp.log(1-lambda_xb)
        ll = -jnp.sum(ll_i)
        return ll

      beta_init = jnp.array([0.01,0.01])

      # using scipy
      scipy_res = scipy_opt.minimize(
        fun=binary_logit_log_likelihood,
        args=(onp.asarray(y),onp.asarray(x)),
        x0 = (onp.asarray(beta_init)), method='BFGS'
      ).x

      # using jaxopt
      solver = LBFGS(fun=binary_logit_log_likelihood_jax, maxiter=100,
                     linesearch="zoom", maxls=10, tol=1e-12)
      jaxopt_res = solver.run(beta_init, y, x).params

      # comparison
      scipy_val = binary_logit_log_likelihood(scipy_res,
                                              onp.asarray(y),
                                              onp.asarray(x))
      jaxopt_val = binary_logit_log_likelihood(jaxopt_res, y, x)
>     self.assertArraysAllClose(scipy_val, jaxopt_val)
tests/lbfgs_test.py:422: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
args = (<function assert_allclose.<locals>.compare at 0xfffeb82b56c0>, array(636.76217796), array(636.7615, dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=1e-06', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=1e-06
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 0.00070335
E           Max relative difference: 1.10457124e-06
E            x: array(636.762178)
E            y: array(636.7615, dtype=float32)
/nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/contextlib.py:81: AssertionError
______________________ LinearSolveTest.test_solve_sparse _______________________
[gw32] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11
self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse>
    def test_solve_sparse(self):
      rng = onp.random.RandomState(0)

      # Matrix case.
      A = rng.randn(5, 5)
      b = rng.randn(5)

      def matvec(x):
        return jnp.dot(A, x)

      x = linear_solve.solve_lu(matvec, b)
      x2 = linear_solve.solve_normal_cg(matvec, b)
      x3 = linear_solve.solve_gmres(matvec, b)
      x4 = linear_solve.solve_bicgstab(matvec, b)
      x5 = linear_solve.solve_iterative_refinement(matvec, b)
      x6 = linear_solve.solve_qr(matvec, b)

      self.assertArraysAllClose(x, x2, atol=1e-4)
>     self.assertArraysAllClose(x, x3, atol=1e-4)
tests/linear_solve_test.py:132: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jaxopt/_src/test_util.py:292: in assertArraysAllClose
    _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg)
jaxopt/_src/test_util.py:262: in _assert_numpy_allclose
    onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
args = (<function assert_allclose.<locals>.compare at 0xfffe946f58a0>, array([-6.9443398, -1.9871643,  7.747069 ,  7.654946 ,...875  ],
      dtype=float32), array([-6.944449 , -1.9872042,  7.747199 ,  7.655077 , -7.038865 ],
      dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=1e-06, atol=0.0001
E           
E           Mismatched elements: 4 / 5 (80%)
E           Max absolute difference: 0.00013113
E           Max relative difference: 2.009613e-05
E            x: array([-6.94434 , -1.987164,  7.747069,  7.654946, -7.03875 ],
E                 dtype=float32)
E            y: array([-6.944449, -1.987204,  7.747199,  7.655077, -7.038865],
E                 dtype=float32)
/nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/contextlib.py:81: AssertionError
____________ PolyakSgdTest.test_logreg_with_intercept_manual_loop3 _____________
[gw13] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11
self = <polyak_sgd_test.PolyakSgdTest testMethod=test_logreg_with_intercept_manual_loop3>
momentum = 0.9, sps_variant = 'SPS+'
    @parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+'])
    def test_logreg_with_intercept_manual_loop(self, momentum, sps_variant):
      x, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3,
                                          n_informative=3, random_state=0)
      data = (x, y)
      l2reg = 0.1
      # fun(params, l2reg, data)
      fun = objective.l2_multiclass_logreg_with_intercept
      n_classes = len(jnp.unique(y))

      w_init = jnp.zeros((x.shape[1], n_classes))
      b_init = jnp.zeros(n_classes)
      params = (w_init, b_init)

      opt = PolyakSGD(
          fun=fun, fun_min=0.6975, momentum=momentum, variant=sps_variant
      )
      error_init = opt.l2_optimality_error(params, l2reg=l2reg, data=data)

      state = opt.init_state(params, l2reg=l2reg, data=data)
      for _ in range(200):
        params, state = opt.update(params, state, l2reg=l2reg, data=data)

      # Check optimality conditions.
      error = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
>     self.assertLessEqual(error / error_init, 0.02)
E     AssertionError: Array(0.02369377, dtype=float32) not less than or equal to 0.02
tests/polyak_sgd_test.py:79: AssertionError
=============================== warnings summary ===============================
jaxopt/_src/osqp.py:299
  /build/source/jaxopt/_src/osqp.py:299: DeprecationWarning: invalid escape sequence '\m'
    """Operator Splitting Solver for Quadratic Programs.
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn0
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn0
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn1
tests/isotonic_test.py::IsotonicPavTest::test_compare_with_sklearn1
tests/isotonic_test.py::IsotonicPavTest::test_output_shape_and_dtype
tests/isotonic_test.py::IsotonicPavTest::test_vmap
tests/isotonic_test.py::IsotonicPavTest::test_gradient1
tests/isotonic_test.py::IsotonicPavTest::test_gradient0
tests/isotonic_test.py::IsotonicPavTest::test_gradient_min_max
  /build/source/jaxopt/_src/isotonic.py:94: UserWarning: Numba could not be imported. Code will run much more slowly. To install, run 'pip install numba'.
    warnings.warn(
tests/lbfgs_test.py::LbfgsTest::test_against_scipy1
tests/lbfgs_test.py::LbfgsTest::test_against_scipy3
tests/lbfgs_test.py::LbfgsTest::test_against_scipy0
tests/lbfgs_test.py::LbfgsTest::test_against_scipy4
tests/lbfgs_test.py::LbfgsTest::test_minimize_bad_initial_values
tests/lbfgs_test.py::LbfgsTest::test_against_scipy2
  /build/source/jaxopt/_src/lbfgs.py:119: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    fun = lambda leaf: jnp.zeros((history_size,) + leaf.shape, dtype=leaf.dtype)
tests/linear_solve_test.py::LinearSolveTest::test_solve_1d
tests/linear_solve_test.py::LinearSolveTest::test_solve_dense
tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse
tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse_ridge
  /build/source/jaxopt/_src/linear_solve.py:31: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    x = jnp.zeros(shape, dtype)
tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x327
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver inv that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "
tests/loss_test.py::LossTest::test_multiclass_logistic_loss
tests/loss_test.py::LossTest::test_multiclass_sparsemax_loss
  /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/ops/scatter.py:96: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
    warnings.warn(
tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x325
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver lu that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "
tests/common_test.py::CommonTest::test_dtype_consistency
  /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/lax/lax.py:2385: RuntimeWarning: overflow encountered in cast
    out = np.array(c, eqn.params['new_dtype'])
tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev0
  /build/source/jaxopt/_src/scipy_wrappers.py:343: RuntimeWarning: Method Nelder-Mead does not use gradient information (jac).
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev1
  /build/source/jaxopt/_src/scipy_wrappers.py:343: RuntimeWarning: Method Powell does not use gradient information (jac).
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
tests/scipy_wrappers_test.py::ScipyMinimizeTest::test_no_njev2
  /build/source/jaxopt/_src/scipy_wrappers.py:343: OptimizeWarning: Unknown solver options: maxiter
    res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
tests/levenberg_marquardt_test.py::LevenbergMarquardtTest::test_scaled_meyer_x324
  /build/source/jaxopt/_src/levenberg_marquardt.py:507: UserWarning: The linear solver cholesky that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True.
    warnings.warn(f"The linear solver {self.solver} that requires materialization of "
tests/scipy_wrappers_test.py::ScipyRootFindingTest::test_broyden
  /nix/store/1z0wr5pb0ckj88qy92mwh7zkc0yaym80-python3.11-scipy-1.12.0/lib/python3.11/site-packages/scipy/optimize/_root.py:245: RuntimeWarning: Method broyden1 does not use the jacobian (jac).
    _warn_jac_unused(jac, method)
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
tests/mirror_descent_test.py::MirrorDescentTest::test_multiclass_svm_dual_implicit_diff_kl_stable
  /nix/store/ndvyzqskd5yqzybwfpqk1dyc9qp2k00f-python3.11-scikit-learn-1.4.0/lib/python3.11/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
    warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/lbfgs_test.py::LbfgsTest::test_binary_logit_log_likelihood - AssertionError: 
FAILED tests/linear_solve_test.py::LinearSolveTest::test_solve_sparse - AssertionError: 
FAILED tests/polyak_sgd_test.py::PolyakSgdTest::test_logreg_with_intercept_manual_loop3 - AssertionError: Array(0.02369377, dtype=float32) not less than or equal to ...
============ 3 failed, 552 passed, 6 skipped, 33 warnings in 49.90s ============