Open ghost opened 9 months ago
jaxopt 0.8.3 (with a patch applied) fails the following unit tests after updating to SciPy 1.12:
=================================== FAILURES =================================== __________________ LbfgsTest.test_binary_logit_log_likelihood __________________ [gw45] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11 self = <lbfgs_test.LbfgsTest testMethod=test_binary_logit_log_likelihood> def test_binary_logit_log_likelihood(self): # See issue #409 rng = jax.random.PRNGKey(42) N = 1000 beta = jnp.array([[0.5,0.5]]).T income = jax.random.normal(rng, shape=(N,1)) x = jnp.hstack([jnp.ones((N,1)), income]) def simulate_binary_logit(x, beta): beta = beta.reshape(-1,1) N = x.shape[0] J = beta.shape[0] epsilon = jax.random.gumbel(rng,shape =(N,J)) Beta_augmented = jnp.hstack([beta, jnp.zeros_like(beta)]) utility = x @ Beta_augmented + epsilon choice_idx = onp.argmax(utility, axis=1) return (choice_idx).reshape(-1,1) y = simulate_binary_logit(x, beta) y = jnp.ravel(y) # numpy version def binary_logit_log_likelihood(beta, y,x): lambda_xb = onp.exp(x@beta) / (1 + onp.exp(x@beta)) ll_i = y * onp.log(lambda_xb) + (1-y) * onp.log(1-lambda_xb) ll = -onp.sum(ll_i) return ll # jax version def binary_logit_log_likelihood_jax(beta, y, x): lambda_xb = jnp.exp(x@beta) / (1 + jnp.exp(x@beta)) ll_i = y * jnp.log(lambda_xb) + (1-y) * jnp.log(1-lambda_xb) ll = -jnp.sum(ll_i) return ll beta_init = jnp.array([0.01,0.01]) # using scipy scipy_res = scipy_opt.minimize( fun=binary_logit_log_likelihood, args=(onp.asarray(y),onp.asarray(x)), x0 = (onp.asarray(beta_init)), method='BFGS' ).x # using jaxopt solver = LBFGS(fun=binary_logit_log_likelihood_jax, maxiter=100, linesearch="zoom", maxls=10, tol=1e-12) jaxopt_res =, y, x).params # comparison scipy_val = binary_logit_log_likelihood(scipy_res, onp.asarray(y), onp.asarray(x)) jaxopt_val = binary_logit_log_likelihood(jaxopt_res, y, x) > self.assertArraysAllClose(scipy_val, jaxopt_val) tests/ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ jaxopt/_src/ in 
assertArraysAllClose _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg) jaxopt/_src/ in _assert_numpy_allclose onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ args = (<function assert_allclose.<locals>.compare at 0xfffeb82b56c0>, array(636.76217796), array(636.7615, dtype=float32)) kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=1e-06', 'verbose': True} @wraps(func) def inner(*args, **kwds): with self._recreate_cm(): > return func(*args, **kwds) E AssertionError: E Not equal to tolerance rtol=1e-06, atol=1e-06 E E Mismatched elements: 1 / 1 (100%) E Max absolute difference: 0.00070335 E Max relative difference: 1.10457124e-06 E x: array(636.762178) E y: array(636.7615, dtype=float32) /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/ AssertionError ______________________ LinearSolveTest.test_solve_sparse _______________________ [gw32] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11 self = <linear_solve_test.LinearSolveTest testMethod=test_solve_sparse> def test_solve_sparse(self): rng = onp.random.RandomState(0) # Matrix case. 
A = rng.randn(5, 5) b = rng.randn(5) def matvec(x): return, x) x = linear_solve.solve_lu(matvec, b) x2 = linear_solve.solve_normal_cg(matvec, b) x3 = linear_solve.solve_gmres(matvec, b) x4 = linear_solve.solve_bicgstab(matvec, b) x5 = linear_solve.solve_iterative_refinement(matvec, b) x6 = linear_solve.solve_qr(matvec, b) self.assertArraysAllClose(x, x2, atol=1e-4) > self.assertArraysAllClose(x, x3, atol=1e-4) tests/ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ jaxopt/_src/ in assertArraysAllClose _assert_numpy_allclose(x, y, atol=atol, rtol=rtol, err_msg=err_msg) jaxopt/_src/ in _assert_numpy_allclose onp.testing.assert_allclose(a, b, **kw, err_msg=err_msg) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ args = (<function assert_allclose.<locals>.compare at 0xfffe946f58a0>, array([-6.9443398, -1.9871643, 7.747069 , 7.654946 ,...875 ], dtype=float32), array([-6.944449 , -1.9872042, 7.747199 , 7.655077 , -7.038865 ], dtype=float32)) kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-06, atol=0.0001', 'verbose': True} @wraps(func) def inner(*args, **kwds): with self._recreate_cm(): > return func(*args, **kwds) E AssertionError: E Not equal to tolerance rtol=1e-06, atol=0.0001 E E Mismatched elements: 4 / 5 (80%) E Max absolute difference: 0.00013113 E Max relative difference: 2.009613e-05 E x: array([-6.94434 , -1.987164, 7.747069, 7.654946, -7.03875 ], E dtype=float32) E y: array([-6.944449, -1.987204, 7.747199, 7.655077, -7.038865], E dtype=float32) /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/lib/python3.11/ AssertionError ____________ PolyakSgdTest.test_logreg_with_intercept_manual_loop3 _____________ [gw13] linux -- Python 3.11.7 /nix/store/dz8lm4h0ivibad5kfc0ya3p3zqyd2fyf-python3-3.11.7/bin/python3.11 self = <polyak_sgd_test.PolyakSgdTest testMethod=test_logreg_with_intercept_manual_loop3> momentum = 0.9, sps_variant = 'SPS+' 
@parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+']) def test_logreg_with_intercept_manual_loop(self, momentum, sps_variant): x, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3, n_informative=3, random_state=0) data = (x, y) l2reg = 0.1 # fun(params, l2reg, data) fun = objective.l2_multiclass_logreg_with_intercept n_classes = len(jnp.unique(y)) w_init = jnp.zeros((x.shape[1], n_classes)) b_init = jnp.zeros(n_classes) params = (w_init, b_init) opt = PolyakSGD( fun=fun, fun_min=0.6975, momentum=momentum, variant=sps_variant ) error_init = opt.l2_optimality_error(params, l2reg=l2reg, data=data) state = opt.init_state(params, l2reg=l2reg, data=data) for _ in range(200): params, state = opt.update(params, state, l2reg=l2reg, data=data) # Check optimality conditions. error = opt.l2_optimality_error(params, l2reg=l2reg, data=data) > self.assertLessEqual(error / error_init, 0.02) E AssertionError: Array(0.02369377, dtype=float32) not less than or equal to 0.02 tests/ AssertionError =============================== warnings summary =============================== jaxopt/_src/ /build/source/jaxopt/_src/ DeprecationWarning: invalid escape sequence '\m' """Operator Splitting Solver for Quadratic Programs. tests/ tests/ tests/ tests/ tests/ tests/ tests/ tests/ tests/ /build/source/jaxopt/_src/ UserWarning: Numba could not be imported. Code will run much more slowly. To install, run 'pip install numba'. warnings.warn( tests/ tests/ tests/ tests/ tests/ tests/ /build/source/jaxopt/_src/ UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See for more. 
fun = lambda leaf: jnp.zeros((history_size,) + leaf.shape, dtype=leaf.dtype) tests/ tests/ tests/ tests/ /build/source/jaxopt/_src/ UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See for more. x = jnp.zeros(shape, dtype) tests/ /build/source/jaxopt/_src/ UserWarning: The linear solver inv that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True. warnings.warn(f"The linear solver {self.solver} that requires materialization of " tests/ tests/ /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/ops/ FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error. warnings.warn( tests/ /build/source/jaxopt/_src/ UserWarning: The linear solver lu that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True. warnings.warn(f"The linear solver {self.solver} that requires materialization of " tests/ /nix/store/kay9rbfsfmi0mlp7f19xqxyykk2kb00b-python3.11-jax-0.4.23/lib/python3.11/site-packages/jax/_src/lax/ RuntimeWarning: overflow encountered in cast out = np.array(c, eqn.params['new_dtype']) tests/ /build/source/jaxopt/_src/ RuntimeWarning: Method Nelder-Mead does not use gradient information (jac). 
res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype), tests/ /build/source/jaxopt/_src/ RuntimeWarning: Method Powell does not use gradient information (jac). res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype), tests/ /build/source/jaxopt/_src/ OptimizeWarning: Unknown solver options: maxiter res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype), tests/ /build/source/jaxopt/_src/ UserWarning: The linear solver cholesky that requires materialization of J^T.J matrix is used with materialize_jac=False, which may cause a computational overhead. Consider using either a matrix-free iterative solver such as cg or bicg or using materialize_jac=True. warnings.warn(f"The linear solver {self.solver} that requires materialization of " tests/ /nix/store/1z0wr5pb0ckj88qy92mwh7zkc0yaym80-python3.11-scipy-1.12.0/lib/python3.11/site-packages/scipy/optimize/ RuntimeWarning: Method broyden1 does not use the jacobian (jac). _warn_jac_unused(jac, method) tests/ tests/ tests/ /nix/store/ndvyzqskd5yqzybwfpqk1dyc9qp2k00f-python3.11-scikit-learn-1.4.0/lib/python3.11/site-packages/sklearn/svm/ ConvergenceWarning: Liblinear failed to converge, increase the number of iterations. warnings.warn( -- Docs: =========================== short test summary info ============================ FAILED tests/ - AssertionError: FAILED tests/ - AssertionError: FAILED tests/ - AssertionError: Array(0.02369377, dtype=float32) not less than or equal to ... ============ 3 failed, 552 passed, 6 skipped, 33 warnings in 49.90s ============
jaxopt 0.8.3 (with a patch applied) fails the unit tests listed above after updating to SciPy 1.12.