JaxGaussianProcesses / GPJax

Gaussian processes in JAX.
https://docs.jaxgaussianprocesses.com/
Apache License 2.0

test(test_gps.py): fix jaxtyping.TypeCheckError #437

Closed: stephen-huan closed this 5 months ago

stephen-huan commented 6 months ago

Description

My pytest header is as follows.

platform linux -- Python 3.11.7, pytest-7.4.4, pluggy-1.4.0
rootdir: ...
configfile: pyproject.toml
plugins: jaxtyping-0.2.25, typeguard-4.1.5

Running pytest tests/test_gps.py gives

============================================================================== short test summary info ===============================================================================
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function0-RBF-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function0-RBF-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function0-Matern52-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function0-Matern52-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function1-RBF-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function1-RBF-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function1-Matern52-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_prior_sample_approx[mean_function1-Matern52-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function0-RBF-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function0-RBF-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function0-Matern52-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function0-Matern52-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function1-RBF-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function1-RBF-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function1-Matern52-1] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
FAILED tests/test_gps.py::test_conjugate_posterior_sample_approx[mean_function1-Matern52-5] - jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of sample_approx.
=========================================================================== 16 failed, 42 passed in 3.83s ============================================================================

This patch fixes the tests by ignoring the errors raised by jaxtyping, in the same way that beartype's errors are already ignored.
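A minimal sketch of that approach, assuming the test gathers every acceptable exception type into one tuple and passes it to pytest.raises. The names validation_errors, ValidationErrors, hypothetical_sample_approx and expect_validation_error are illustrative and are not taken from the GPJax test suite. Each optional checker's error type is added only if that package can be imported, so the same test works with or without beartype or jaxtyping installed.

```python
from __future__ import annotations


def validation_errors() -> tuple[type[BaseException], ...]:
    """Collect the exception types a validation test should accept (illustrative helper)."""
    # The library's own input validation raises ValueError.
    errors: list[type[BaseException]] = [ValueError]
    try:
        # Raised by beartype when a call violates a parameter annotation.
        from beartype.roar import BeartypeCallHintParamViolation

        errors.append(BeartypeCallHintParamViolation)
    except ImportError:
        pass  # beartype not installed (or too old): nothing to add
    try:
        # Raised by jaxtyping when a runtime type check fails.
        from jaxtyping import TypeCheckError

        errors.append(TypeCheckError)
    except ImportError:
        pass  # jaxtyping not installed (or predates TypeCheckError)
    return tuple(errors)


ValidationErrors = validation_errors()


def hypothetical_sample_approx(num_samples: int) -> list[float]:
    """Hypothetical stand-in for sample_approx that rejects bad input."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    return [0.0] * num_samples


def expect_validation_error(fn, *args) -> bool:
    """Return True if fn(*args) raises any of the accepted validation errors."""
    try:
        fn(*args)
    except ValidationErrors:  # a tuple of exception types is valid here
        return True
    return False
```

Inside a real test the tuple would be used as "with pytest.raises(ValidationErrors): ...", since pytest.raises also accepts a tuple of exception types.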

The failures may be a consequence of my (updated) dependencies, since the tests pass on CI.
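One way to compare a local environment with CI is to print the installed versions of the packages involved. The sketch below uses importlib.metadata from the standard library; the helper name installed_versions and the package list (based on the plugins shown in the pytest header above) are assumptions for illustration.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages: list[str]) -> dict[str, str | None]:
    """Map each distribution name to its installed version, or None if absent."""
    found: dict[str, str | None] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Packages involved in runtime type checking of the test suite (assumed list).
    for name, ver in installed_versions(
        ["jax", "jaxtyping", "beartype", "typeguard", "pytest"]
    ).items():
        print(f"{name}: {ver or 'not installed'}")
```

Running the script both locally and in the CI job and comparing the two outputs shows which dependency versions differ.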

P.S. flax is declared as a dev and docs dependency, but it is used in the tests, so it should also be added to the test dependencies.

stephen-huan commented 5 months ago

Addressed by #442.