blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

DeprecationWarning from jax 0.4.27 #663

Closed GaetanLepage closed 4 months ago

GaetanLepage commented 5 months ago

Describe the issue as clearly as possible:

Since jax 0.4.27, the following tests fail:

```
=========================== short test summary info ============================
FAILED tests/adaptation/test_step_size.py::StepSizeTest::test_reasonable_step_size__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/adaptation/test_step_size.py::StepSizeTest::test_reasonable_step_size__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/adaptation/test_step_size.py::StepSizeTest::test_reasonable_step_size__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/adaptation/test_step_size.py::StepSizeTest::test_reasonable_step_size__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_random_walk_without_chex.py::IRMHTest::test_non_symmetric_proposal - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_random_walk_without_chex.py::IRMHTest::test_proposal_is_independent_of_position - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/adaptation/test_adaptation.py::test_chees_adaptation - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_barker__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_barker__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_hmc__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LatentGaussianTest::test_latent_gaussian__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_irmh__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_mala__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_hmc__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_barker__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_barker__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_mala__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_irmh__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_ghmc__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_ghmc__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_nuts__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_hmc__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_nuts__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_mala__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_irmh__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LatentGaussianTest::test_latent_gaussian__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmh__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_random_walk_without_chex.py::AdditiveStepTest::test_one_step_addition - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_random_walk__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_hmc__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_mala__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_irmh__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_ghmc__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_ghmc__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmh__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_nuts__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_nuts__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmh__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_random_walk__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_latent_gaussian.py::GaussianTest::test_gaussian2 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1e-10, False, False, 10)__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_random_walk__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1, False, True, 2)__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1, False, True, 2)__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmh__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_window_adaptation1 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_barker - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmhmc__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(100000, True, True, 1)__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_latent_gaussian.py::GaussianTest::test_gaussian3 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_mala - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmhmc__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(100000, True, True, 1)__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_random_walk__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_latent_gaussian.py::GaussianTest::test_gaussian0 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(100000, True, True, 1)__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_window_adaptation3 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1, False, True, 2)__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1e-10, False, False, 10)__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1e-10, False, False, 10)__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmhmc__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_window_adaptation2 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::UnivariateNormalTest::test_rmhmc__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(100000, True, True, 1)__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_latent_gaussian.py::GaussianTest::test_gaussian1 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_chees0 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1e-10, False, False, 10)__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LatentGaussianTest::test_latent_gaussian__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_chees1 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LatentGaussianTest::test_latent_gaussian__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_tempered_smc__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_adaptive_tempered__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_adaptive_tempered__without_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_tempered_smc__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_adaptive_tempered__without_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::MonteCarloStandardErrorTest::test_mcse0 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_tempered_smc__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::MonteCarloStandardErrorTest::test_mcse2 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::SMCParameterTuningTest::test_smc_inner_kernel_tempered - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_meads - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::MonteCarloStandardErrorTest::test_mcse1 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::MonteCarloStandardErrorTest::test_mcse4 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::MonteCarloStandardErrorTest::test_mcse3 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_kernel_compatibility.py::SMCAndMCMCIntegrationTest::test_compatible_with_irmh - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_kernel_compatibility.py::SMCAndMCMCIntegrationTest::test_compatible_with_rmh - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_hmc_integration_steps - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_smc.py::SMCTest::test_smc_waste_free__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_tempered_smc__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_kernel_compatibility.py::SMCAndMCMCIntegrationTest::test_compatible_with_hmc - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_kernel_compatibility.py::SMCAndMCMCIntegrationTest::test_compatible_with_rwm - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_kernel_compatibility.py::SMCAndMCMCIntegrationTest::test_compatible_with_nuts - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_smc.py::SMCTest::test_smc__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_tempered_smc.py::TemperedSMCTest::test_fixed_schedule_tempered_smc__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_compilation.py::CompilationTest::test_nuts - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_compilation.py::CompilationTest::test_hmc_warmup - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_compilation.py::CompilationTest::test_hmc - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_compilation.py::CompilationTest::test_nuts_warmup - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_tempered_smc.py::NormalizingConstantTest::test_normalizing_constant__with_jit - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_pathfinder_adaptation1 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::SMCParameterTuningTest::test_smc_inner_kernel_adaptive_tempered - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_window_adaptation0 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_inner_kernel_tuning.py::InnerKernelTuningJitTest::test_with_adaptive_tempered__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_util.py::RunInferenceAlgorithmTest::test_compatible_with_initial_pos0 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_util.py::RunInferenceAlgorithmTest::test_compatible_with_initial_state1 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_sampling.py::LinearRegressionTest::test_pathfinder_adaptation0 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_util.py::RunInferenceAlgorithmTest::test_compatible_with_initial_state0 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/test_util.py::RunInferenceAlgorithmTest::test_compatible_with_initial_pos1 - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/smc/test_kernel_compatibility.py::SMCAndMCMCIntegrationTest::test_compatible_with_mala - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
FAILED tests/mcmc/test_trajectory.py::TrajectoryTest::test_dynamic_progressive_expansion_(1, False, True, 2)__with_device - DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy...
============ 109 failed, 426 passed, 1 skipped in 115.04s (0:01:55) ============
```

Steps/code to reproduce the bug:

pytest

Expected result:

All tests pass.

Error message:

E       DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is deprecated. Please use 'x', 'min', and 'max' respectively instead.
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
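The message above points at a keyword rename (`a`, `a_min`, `a_max` become `x`, `min`, `max`). The fact that a warning shows up as a test failure suggests the test run escalates warnings to errors, though that is an assumption, not something confirmed in this report. Below is a minimal standard-library sketch of that mechanism. The `clip` function here is a hypothetical stand-in written for illustration and is not JAX's implementation; only the warning text is taken from the error message.

```python
import warnings


def clip(x=None, min=None, max=None, *, a=None, a_min=None, a_max=None):
    """Hypothetical scalar clip that accepts old and new keyword names."""
    if a is not None or a_min is not None or a_max is not None:
        # Mirror the style of deprecation reported in the error message.
        warnings.warn(
            "Passing arguments 'a', 'a_min', or 'a_max' is deprecated. "
            "Please use 'x', 'min', and 'max' respectively instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        x = a if x is None else x
        min = a_min if min is None else min
        max = a_max if max is None else max
    # Clamp the scalar into the closed interval [min, max].
    if min is not None and x < min:
        return min
    if max is not None and x > max:
        return max
    return x


def raises_deprecation(call):
    """Return True if `call` fails once DeprecationWarning is an error."""
    with warnings.catch_warnings():
        # Assumed to resemble what a strict test configuration does.
        warnings.simplefilter("error", DeprecationWarning)
        try:
            call()
        except DeprecationWarning:
            return True
    return False


# Old keyword names trip the escalated warning; new names do not.
old_style_fails = raises_deprecation(lambda: clip(a=1.5, a_min=0.0, a_max=1.0))
new_style_fails = raises_deprecation(lambda: clip(1.5, min=0.0, max=1.0))
```

Under these assumptions, any call site still using the old keyword names would fail in the same way, which would be consistent with one deprecated call surfacing across many tests.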

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 1.2.0
Python 3.11.9 (main, Apr  2 2024, 08:25:04) [GCC 13.2.0]
Jax 0.4.27
Jaxlib 0.4.27

Context for the issue:

Updating jax in nixpkgs: https://github.com/NixOS/nixpkgs/pull/291705

junpenglao commented 4 months ago

I cut a new release, https://github.com/blackjax-devs/blackjax/releases/tag/1.2.1, so it should work now.
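For anyone packaging BlackJAX downstream, a quick way to confirm that an environment has picked up this release might look like the sketch below. It assumes 1.2.1 is the first release containing the fix (inferred from this comment) and that versions are plain dotted numbers. The helper names are hypothetical.

```python
from importlib import metadata


def parse_version(text):
    """Turn a plain dotted version such as '1.2.1' into a tuple of ints.

    Only handles purely numeric versions; suffixes like '.dev0' are not
    supported in this sketch.
    """
    return tuple(int(part) for part in text.split("."))


def has_fix(installed, first_fixed="1.2.1"):
    """Return True if `installed` is at least the assumed first fixed release."""
    return parse_version(installed) >= parse_version(first_fixed)


def installed_blackjax_version():
    """Return the installed blackjax version string, or None if absent."""
    try:
        return metadata.version("blackjax")
    except metadata.PackageNotFoundError:
        return None
```

With the versions from this report, `has_fix("1.2.0")` is false and `has_fix("1.2.1")` is true.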

GaetanLepage commented 4 months ago

Thanks so much!