google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Relax absolute tolerance for failing tests involving chex.assert_trees_all_close. #1069

Closed: carlosgmartin closed this pull request 2 months ago

vroulet commented 2 months ago

Thanks @carlosgmartin! I don't see when or where these tests broke. Could you explain the reason for this change?

carlosgmartin commented 2 months ago

@vroulet Here's the output I get from running tests on my machine (after applying https://github.com/google-deepmind/optax/pull/1068 and https://github.com/google-deepmind/optax/pull/1071):

```
(venv) $ sh ./test.sh; tput bel
-------------------------------------------------------------------
Your code has been rated at 10.00/10 (previous run: 1.41/10, +8.59)

------------------------------------
Your code has been rated at 10.00/10

Collecting build
Using cached build-1.2.2-py3-none-any.whl.metadata (6.2 kB)
Requirement already satisfied: packaging>=19.1 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from build) (24.1)
Collecting pyproject_hooks (from build)
Using cached pyproject_hooks-1.1.0-py3-none-any.whl.metadata (1.3 kB)
Using cached build-1.2.2-py3-none-any.whl (22 kB)
Using cached pyproject_hooks-1.1.0-py3-none-any.whl (9.2 kB)
Installing collected packages: pyproject_hooks, build
Successfully installed build-1.2.2 pyproject_hooks-1.1.0
* Creating isolated environment: venv+pip...
* Installing packages in isolated environment:
  - flit_core >=3.2,<4
* Getting build dependencies for sdist...
* Building sdist...
Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
* Building wheel from sdist
* Creating isolated environment: venv+pip...
* Installing packages in isolated environment:
  - flit_core >=3.2,<4
* Getting build dependencies for wheel...
* Building wheel...
Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
Successfully built optax-0.2.4.dev0.tar.gz and optax-0.2.4.dev0-py3-none-any.whl
Processing ./dist/optax-0.2.4.dev0.tar.gz
Running command pip subprocess to install build dependencies
Using pip 24.2 from /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/pip (python 3.12)
Collecting flit_core<4,>=3.2
Obtaining dependency information for flit_core<4,>=3.2 from https://files.pythonhosted.org/packages/38/45/618e84e49a6c51e5dd15565ec2fcd82ab273434f236b8f108f065ded517a/flit_core-3.9.0-py3-none-any.whl.metadata
Using cached flit_core-3.9.0-py3-none-any.whl.metadata (822 bytes)
Using cached flit_core-3.9.0-py3-none-any.whl (63 kB)
Installing collected packages: flit_core
Successfully installed flit_core-3.9.0
Installing build dependencies ... done
Running command Getting requirements to build wheel
Getting requirements to build wheel ... done
Running command Preparing metadata (pyproject.toml)
Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: optax
Running command Building wheel for optax (pyproject.toml)
Version number normalised: '0.2.4.dev' -> '0.2.4.dev0' (see PEP 440)
Building wheel for optax (pyproject.toml) ... done
Created wheel for optax: filename=optax-0.2.4.dev0-py3-none-any.whl size=302990 sha256=24c29cdf1a9e11eb65d0ec26ec9571a991c1fbacf5bd50f908d66c779f452462
Stored in directory: /Users/carlos/Library/Caches/pip/wheels/12/8b/12/f73f3c3c2e327a62f3b987cf82f38ac6de00e4d01f1a02e9cd
Successfully built optax
Processing ./optax-0.2.4.dev0-py3-none-any.whl
Requirement already satisfied: absl-py>=0.7.1 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (1.4.0)
Requirement already satisfied: chex>=0.1.86 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (0.1.86)
Requirement already satisfied: jax>=0.4.27 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (0.4.33)
Requirement already satisfied: jaxlib>=0.4.27 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (0.4.33)
Requirement already satisfied: numpy>=1.18.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (1.26.4)
Requirement already satisfied: etils[epy] in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from optax==0.2.4.dev0) (1.9.4)
Requirement already satisfied: typing-extensions>=4.2.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from chex>=0.1.86->optax==0.2.4.dev0) (4.12.2)
Requirement already satisfied: toolz>=0.9.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from chex>=0.1.86->optax==0.2.4.dev0) (0.12.1)
Requirement already satisfied: setuptools in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from chex>=0.1.86->optax==0.2.4.dev0) (75.1.0)
Requirement already satisfied: ml-dtypes>=0.2.0 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from jax>=0.4.27->optax==0.2.4.dev0) (0.4.1)
Requirement already satisfied: opt-einsum in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from jax>=0.4.27->optax==0.2.4.dev0) (3.3.0)
Requirement already satisfied: scipy>=1.10 in ./_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages (from jax>=0.4.27->optax==0.2.4.dev0) (1.14.1)
Installing collected packages: optax
Successfully installed optax-0.2.4.dev0
Computing dependencies
Analyzing 48 sources with 20 local dependencies
ninja: Entering directory `.pytype'
[68/68] check constrain
Leaving directory '.pytype'
Success: no errors found
=============================== test session starts ================================
platform darwin -- Python 3.12.3, pytest-8.3.3, pluggy-1.5.0
rootdir: /Users/carlos/Desktop/optax
configfile: pyproject.toml
plugins: xdist-3.6.1
16 workers [3564 items]
ss.....s..s........s...s.......s.......s..s......................ss......... [ 2%]
.....s..ss....s......s....s...................s......s...s....s.......s..... [ 4%]
.........s.....ss....s.s.....s.....s...s.s....s....s...s.s..s...........s... [ 6%]
....s....s.s.s..s...s........s..s......s........s...s...s.s....s....s....s.. [ 8%]
.s..s.s...s.....s.s....s.s.....s....s..s.....s..s......s.......s...s......s. [ 10%]
..s...s...ss.....s.s..s....s........s....s.....s.s...s...s.s..s.......s...ss [ 12%]
...s..........s.....s.s.......s.s.....s....s...s....s....s...s....s.s......s [ 14%]
.s.............s.s........s.s...s...sss.s....s....s...s.......s.........s.s. [ 17%]
....s.s..s.s..s.........s......s....s.........s..s...s.s...s....s........s.. [ 19%]
s....s....s.......s..s.....s.....s.s........s....s.s...s....s..s..........s. [ 21%]
...s.....s.....s.s....s.....s.....ss.......ss.s.......s..s.....s............ [ 23%]
.....ss.........s..s..s....s....s..s....s............s....s.....s.s......ss. [ 25%]
........s..........s..ss.........s...s....s......ss.s....s......s.....s..... [ 27%]
s..........s......s......s....s...s......s....s...........s......s....s..... [ 29%]
.....s.....s......s..s.......s.........s......s.s.s............s...s..ss.... [ 31%]
..s..s..............s......ss.....s.....s...s..s........s...........s..s.... [ 34%]
....s...s.........s.....s......s.......s..................s..s..s........... [ 36%]
..s.............s..............s..s........s......s......................s.. [ 38%]
s................s................s..s......s....s....s....s......s......... [ 40%]
.s........s....s..s...........s...s..................s.s..s...s..s.......... [ 42%]
...s........s.s...s..s..........s...s.s...........s.........ss....s......... [ 44%]
...............s..ss....s...s....s.s.s.s..........s............s..s......... [ 46%]
..s.s....s...s...s........s...s..s.....s.s.s.....s.s..s.s...........s....s.. [ 49%]
s.........s......s.......s.......s......s....s......s..s...s.....s.......... [ 51%]
..s.......s....s..s..s.......s...s.s.....sssss.s......s.s....s.............s [ 53%]
..s...s.....s...s.....s....s.s........s...........s...s.....s.....s.s....s.. [ 55%]
..s.....s.....s..s.s.....s.....s...s...s.........s......s...s.......s....... [ 57%]
...s.s.......s..ss....s.......s......s........s....s.s.........s...s...s.... [ 59%]
...s.....s.s..............s.s......s......s....s.............s.............. [ 61%]
..s....s.....s....s..........s.....s.s............s..s.s...........s...s.... [ 63%]
s.s.........s.....s..........s.........s....s..........s.s......s...s....s.. [ 66%]
...s.......s...s......s....s........s.....s.....s............s.s.....s...s.. [ 68%]
..s....s....s........s.....s.sF........s..............s.....s.......s.s...s. [ 70%]
........s........s.......s.s..........s.....s.....s.........s........s...s.. [ 72%]
...s..s.........s.........s....s....s.s........s...............s...........s [ 74%]
.s.........s.s...s........s.s..s..s.s..ss.....s.....ss.......s..........s.s. [ 76%]
..s....s......s....s.s.....s...s........s..s.....s.s.s.......s.......s.....s [ 78%]
.......s...........s...s........s....ss.s.......s....s.s.s.....s....s....... [ 81%]
..........s....s.....s.......s.s...s....s.s....s....s..s...s....s....s...... [ 83%]
...s.......ss.........s.....s.....s...sss.....s.....s.....s........s......s. [ 85%]
.s....s......s.....s.......s..s.....s.s...s...s.s.....s....s.....s......s..s [ 87%]
....s....s.s...............s...s.s........s....s...........s....s....s....s. [ 89%]
s..........s.s...s.s......s...s...s...s....s.s......s..s.s............s.s... [ 91%]
....s..s....ss.............s......s....s...s.ss..s...s..........s.....s..s.. [ 93%]
....s.s.....s......s......s......s......s.....s.....s.s........s.s.......ss. [ 95%]
......ss.s.....s............s............................................... [ 98%]
..................s...........s.........F................s........s. [100%]
===================================== FAILURES =====================================
_______________________ ZoomLinesearchTest.test_linesearch7 ________________________
[gw1] darwin -- Python 3.12.3 /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/bin/python3

self = 
problem_name = 'rosenbrock', seed = 1

    @parameterized.product(
        problem_name=[
            'polynomial',
            'exponential',
            'sinusoidal',
            'rosenbrock',
            'himmelblau',
            'matyas',
            'eggholder'
        ],
        seed=[0, 1],
    )
    def test_linesearch(self, problem_name: str, seed: int):
      """Test backtracking linesearch (single update step)."""
      # Fixed tolerances, we check the behavior in standard conditions
      slope_rtol = 1e-4
      curv_rtol = 0.9
      tol = 0.
      key = jrd.PRNGKey(seed)
      params_key, precond_key = jrd.split(key, 2)
      problem = get_problem(problem_name)
      fn, input_shape = problem['fn'], problem['input_shape']
      init_params = jrd.normal(params_key, input_shape)
      precond_vec = jrd.uniform(precond_key, input_shape)
      # Mimics a preconditioning by a diagonal matrix with non-negative entries
      # (non-negativity ensures that we keep a descent direction)
      init_updates = -precond_vec*jax.grad(fn)(init_params)
      opt_args = dict(
          max_linesearch_steps=30,
          slope_rtol=slope_rtol,
          curv_rtol=curv_rtol,
          tol=tol,
          max_learning_rate=None
      )
      opt = _linesearch.scale_by_zoom_linesearch(**opt_args)
      final_params, final_state = _run_linesearch(
          opt, fn, init_params, init_updates
      )
      scipy_res = scipy_optimize.line_search(
          fn, jax.grad(fn), init_params, init_updates
      )
      with self.subTest('Check value and grad in zoom state'):
        self._check_value_and_grad_in_zoom_state(
            final_params, final_state, value_fn=fn
        )
      with self.subTest('Check linesearch conditions'):
        self._check_linesearch_conditions(
            fn, init_params, init_updates, final_params, final_state, opt_args
        )
      with self.subTest('Check against scipy'):
        stepsize = otu.tree_get(final_state, 'learning_rate')
        final_value = otu.tree_get(final_state, 'value')
        chex.assert_trees_all_close(scipy_res[0], stepsize, atol=1e-5)
>       chex.assert_trees_all_close(scipy_res[3], final_value, atol=1e-5)

optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/linesearch_test.py:483: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:278: in _chex_assert_fn
    host_assertion_fn(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
custom_message = None, custom_message_format_vars = ()
include_default_message = True, exception_type = 
args = (Array(3211.6665, dtype=float32), Array(3211.6628, dtype=float32))
kwargs = {'atol': 1e-05}
assertion_exc = AssertionError('[Chex] Assertion assert_trees_all_equal_comparator failed: Trees (arrays) 0 and 1 differ: \nNot equal ...34e-06\n x: array(3211.6665, dtype=float32)\n y: array(3211.6628, dtype=float32) \nOriginal dtypes: float32, float32.')
value_exc = None
error_msg = '[Chex] Assertion assert_trees_all_close failed: Trees (arrays) 0 and 1 differ: \nNot equal to tolerance rtol=1e-06, ...534e-06\n x: array(3211.6665, dtype=float32)\n y: array(3211.6628, dtype=float32) \nOriginal dtypes: float32, float32.'
default_msg = 'Assertion assert_trees_all_close failed: '

    def _assert_on_host(*args,
                        custom_message: Optional[str] = None,
                        custom_message_format_vars: Sequence[Any] = (),
                        include_default_message: bool = True,
                        exception_type: Type[Exception] = AssertionError,
                        **kwargs) -> None:
      # Format error's stack trace to remove Chex' internal frames.
      assertion_exc = None
      value_exc = None
      try:
        assert_fn(*args, **kwargs)
      except AssertionError as e:
        assertion_exc = e
      except ValueError as e:
        value_exc = e
      finally:
        if value_exc is not None:
          raise ValueError(str(value_exc))
        if assertion_exc is not None:
          # Format the exception message.
          error_msg = str(assertion_exc)
          # Include only the name of the outermost chex assertion.
          if error_msg.startswith(ERR_PREFIX):
            error_msg = error_msg[error_msg.find("failed:") + len("failed:"):]
          # Whether to include the default error message.
          default_msg = (f"Assertion {name} failed: "
                         if include_default_message else "")
          error_msg = f"{ERR_PREFIX}{default_msg}{error_msg}"
          # Whether to include a custom error message.
          if custom_message:
            if custom_message_format_vars:
              custom_message = custom_message.format(*custom_message_format_vars)
            error_msg = f"{error_msg} [{custom_message}]"
>         raise exception_type(error_msg)
E         AssertionError: [Chex] Assertion assert_trees_all_close failed: Trees (arrays) 0 and 1 differ:
E         Not equal to tolerance rtol=1e-06, atol=1e-05
E         Error in value equality check: Values not approximately equal
E         Mismatched elements: 1 / 1 (100%)
E         Max absolute difference: 0.00366211
E         Max relative difference: 1.1402534e-06
E          x: array(3211.6665, dtype=float32)
E          y: array(3211.6628, dtype=float32)
E         Original dtypes: float32, float32.

optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:196: AssertionError
_______________________ LBFGSTest.test_plain_preconditioning _______________________
[gw12] darwin -- Python 3.12.3 /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/bin/python3

self = 

    def test_plain_preconditioning(self):
      key = jrd.PRNGKey(0)
      key_ws, key_us, key_vec = jrd.split(key, 3)
      m = 4
      d = 3
      dws = jrd.normal(key_ws, (m, d))
      dus = jrd.normal(key_us, (m, d))
      rhos = 1.0 / jnp.sum(dws * dus, axis=1)
      vec = jrd.normal(key_vec, (d,))
      plain_precond_vec = _plain_preconditioning(dws, dus, vec)
      precond_mat = _materialize_approx_inv_hessian(dws, dus, rhos, memory_idx=0)
      expected_precond_vec = precond_mat.dot(
          vec, precision=jax.lax.Precision.HIGHEST
      )
>     chex.assert_trees_all_close(plain_precond_vec, expected_precond_vec)

optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py:589: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:278: in _chex_assert_fn
    host_assertion_fn(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
custom_message = None, custom_message_format_vars = ()
include_default_message = True, exception_type = 
args = (Array([ 146.7743 , 82.296425, -844.90216 ], dtype=float32), Array([ 146.77422, 82.29634, -844.9018 ], dtype=float32))
kwargs = {}
assertion_exc = AssertionError('[Chex] Assertion assert_trees_all_equal_comparator failed: Trees (arrays) 0 and 1 differ: \nNot equal ..., dtype=float32)\n y: array([ 146.77422, 82.29634, -844.9018 ], dtype=float32) \nOriginal dtypes: float32, float32.')
value_exc = None
error_msg = '[Chex] Assertion assert_trees_all_close failed: Trees (arrays) 0 and 1 differ: \nNot equal to tolerance rtol=1e-06, ...], dtype=float32)\n y: array([ 146.77422, 82.29634, -844.9018 ], dtype=float32) \nOriginal dtypes: float32, float32.'
default_msg = 'Assertion assert_trees_all_close failed: '

    def _assert_on_host(*args,
                        custom_message: Optional[str] = None,
                        custom_message_format_vars: Sequence[Any] = (),
                        include_default_message: bool = True,
                        exception_type: Type[Exception] = AssertionError,
                        **kwargs) -> None:
      # Format error's stack trace to remove Chex' internal frames.
      assertion_exc = None
      value_exc = None
      try:
        assert_fn(*args, **kwargs)
      except AssertionError as e:
        assertion_exc = e
      except ValueError as e:
        value_exc = e
      finally:
        if value_exc is not None:
          raise ValueError(str(value_exc))
        if assertion_exc is not None:
          # Format the exception message.
          error_msg = str(assertion_exc)
          # Include only the name of the outermost chex assertion.
          if error_msg.startswith(ERR_PREFIX):
            error_msg = error_msg[error_msg.find("failed:") + len("failed:"):]
          # Whether to include the default error message.
          default_msg = (f"Assertion {name} failed: "
                         if include_default_message else "")
          error_msg = f"{ERR_PREFIX}{default_msg}{error_msg}"
          # Whether to include a custom error message.
          if custom_message:
            if custom_message_format_vars:
              custom_message = custom_message.format(*custom_message_format_vars)
            error_msg = f"{error_msg} [{custom_message}]"
>         raise exception_type(error_msg)
E         AssertionError: [Chex] Assertion assert_trees_all_close failed: Trees (arrays) 0 and 1 differ:
E         Not equal to tolerance rtol=1e-06, atol=0
E         Error in value equality check: Values not approximately equal
E         Mismatched elements: 1 / 3 (33.3%)
E         Max absolute difference: 0.00036621
E         Max relative difference: 1.01977e-06
E          x: array([ 146.7743 , 82.296425, -844.90216 ], dtype=float32)
E          y: array([ 146.77422, 82.29634, -844.9018 ], dtype=float32)
E         Original dtypes: float32, float32.

optax-env.TZoLx2zo/lib/python3.12/site-packages/chex/_src/asserts_internal.py:196: AssertionError
================================= warnings summary =================================
_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py: 6 warnings
_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_schedule_free_test.py: 4 warnings
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:118: DeprecationWarning: Casting from complex to real dtypes will soon raise a ValueError. Please first use jnp.real or jnp.imag to take the real/imaginary component of your input.
    return lax_numpy.astype(self, dtype, copy=copy, device=device)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py: 16 warnings
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/transform.py:1521: UserWarning: Explicitly requested dtype float64 requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    lambda leaf: jnp.zeros((memory_size,) + leaf.shape, dtype=leaf.dtype),

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_common_test.py::ContribTest::test_optimizers_can_be_wrapped_in_inject_hyperparams_(opt_name='momo_adam', opt_kwargs={'learning_rate': 0.1}, wrapper_name=None, wrapper_kwargs=None)__without_device
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_momo.py:301: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    bc2 = jnp.asarray(1 - b2 ** count_inc, dtype=barf.dtype)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_common_test.py::ContribTest::test_optimizers_can_be_wrapped_in_inject_hyperparams_(opt_name='momo_adam', opt_kwargs={'learning_rate': 0.1}, wrapper_name=None, wrapper_kwargs=None)__without_device
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/contrib/_momo.py:310: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    bc1 = jnp.asarray(1 - b1 ** count_inc, dtype=barf.dtype)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/tree_utils/_tree_math_test.py::TreeUtilsTest::test_tree_add_scalar_mul
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype complex128 requested in astype is not available, and will be truncated to dtype complex64. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return lax_numpy.astype(self, dtype, copy=copy, device=device)

_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/tree_utils/_tree_math_test.py::TreeUtilsTest::test_tree_add_scalar_mul
  /Users/carlos/Desktop/optax/_testing/optax-env.TZoLx2zo/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return lax_numpy.astype(self, dtype, copy=copy, device=device)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================= short test summary info ==============================
FAILED optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/linesearch_test.py::ZoomLinesearchTest::test_linesearch7 - AssertionError: [Chex] Assertion assert_trees_all_close failed: Trees (arrays)...
FAILED optax-env.TZoLx2zo/lib/python3.12/site-packages/optax/_src/alias_test.py::LBFGSTest::test_plain_preconditioning - AssertionError: [Chex] Assertion assert_trees_all_close failed: Trees (arrays)...
============ 2 failed, 2955 passed, 607 skipped, 30 warnings in 34.51s =============
(venv) $
```

Do you get the same output from running tests on your machine?

vroulet commented 2 months ago

Indeed, I get the error too. I don't understand why it hasn't been caught by the tests on GitHub. Anyway, I'd prefer using small relative tolerances rather than slightly looser absolute ones (the functions these tests verify are supposed to be used for high-precision purposes). So for the L-BFGS test I'd prefer:

    chex.assert_trees_all_close(
        plain_precond_vec, expected_precond_vec, rtol=1e-5
    )

and for the linesearch test:

      chex.assert_trees_all_close(scipy_res[0], stepsize, rtol=1e-5)
      chex.assert_trees_all_close(scipy_res[3], final_value, rtol=1e-5)

I checked: the tests pass with these changes.
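To see why a relative tolerance fits these failures better than a looser absolute one, here is a small sketch that re-checks the values reported in the log. It assumes that chex.assert_trees_all_close applies the same element-wise criterion as NumPy's assert_allclose, |x - y| <= atol + rtol * |y| (the failure messages use NumPy's "Not equal to tolerance rtol=..., atol=..." wording), and uses numpy.allclose with explicit tolerances as a stand-in. The helper name passes is hypothetical.

```python
import numpy as np


def passes(x, y, rtol, atol=0.0):
    """Hypothetical stand-in for the per-leaf check |x - y| <= atol + rtol * |y|."""
    # numpy.allclose defaults to atol=1e-8, so atol is always passed explicitly.
    return bool(np.allclose(x, y, rtol=rtol, atol=atol))


# Values copied from the two failure messages in the log (float32).
ls_scipy = np.float32(3211.6665)  # scipy_res[3] in the zoom linesearch test
ls_optax = np.float32(3211.6628)  # final_value from scale_by_zoom_linesearch
lbfgs_plain = np.array([146.7743, 82.296425, -844.90216], dtype=np.float32)
lbfgs_dense = np.array([146.77422, 82.29634, -844.9018], dtype=np.float32)

# One float32 ulp near 3211 is 2**-12 (about 2.4e-4), already larger than
# atol=1e-5, so an absolute tolerance of that size barely changes the check.
ulp = np.spacing(ls_optax)

# Configurations that failed in the log (chex's default rtol is 1e-6).
old_ls = passes(ls_scipy, ls_optax, rtol=1e-6, atol=1e-5)
old_lbfgs = passes(lbfgs_plain, lbfgs_dense, rtol=1e-6)

# Relative tolerance proposed in the review.
new_ls = passes(ls_scipy, ls_optax, rtol=1e-5)
new_lbfgs = passes(lbfgs_plain, lbfgs_dense, rtol=1e-5)

print(f"ulp near 3211: {ulp:.3e}")
print("old:", old_ls, old_lbfgs, "new:", new_ls, new_lbfgs)
```

With these values both of the configurations from the log fail and rtol=1e-5 passes both, consistent with the check reported above. The mismatches amount to a handful of float32 ulps, so they look like rounding differences rather than a genuine discrepancy, and a relative bound scales with the magnitude of the values where a fixed atol does not.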

vroulet commented 2 months ago

The doctest of _dog also fails on my end. Since you're on it, could you also use ellipses in the doctest of dog, i.e., use the following here:

    Objective function:  13.99...
    Objective function:  13.99...
    Objective function:  13.99...
    Objective function:  13.99...
    Objective function:  13.99...

(This algorithm still needs to be completed with a proper evaluation function and fused with a similar one in the alias/transform files.)
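The ellipsis request relies on the standard-library doctest ELLIPSIS option, under which "..." in the expected output matches any text. A minimal sketch, assuming the flag is enabled either in the project's doctest configuration or per example with a "# doctest: +ELLIPSIS" directive; the report function, its value 13.991234 and the count_failures helper are hypothetical and are not the _dog doctest itself:

```python
import doctest


def report(value):
    """Hypothetical stand-in for an optimizer loop that prints its objective.

    >>> report(13.991234)
    Objective function:  13.99...
    """
    # print() inserts a space between its arguments, hence the double space.
    print("Objective function: ", value)


def count_failures(fn, optionflags=0):
    """Run fn's docstring examples and return the number of failed examples."""
    finder = doctest.DocTestFinder()
    runner = doctest.DocTestRunner(optionflags=optionflags, verbose=False)
    failures = 0
    for test in finder.find(fn, fn.__name__, globs={fn.__name__: fn}):
        # Discard the failure report; only the count matters here.
        failures += runner.run(test, out=lambda _: None).failed
    return failures


strict = count_failures(report)  # exact text match required
relaxed = count_failures(report, optionflags=doctest.ELLIPSIS)
print("failures without ELLIPSIS:", strict)
print("failures with ELLIPSIS:", relaxed)
```

Writing 13.99... instead of every printed digit keeps the doctest stable when the last float32 digits vary between machines, which is the same kind of variation behind the two test failures.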