pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

filter out tests waiting for next tfp release #1817

Closed juanitorduz closed 2 weeks ago

juanitorduz commented 2 weeks ago

Partially addresses https://github.com/pyro-ppl/numpyro/issues/1814. We must remember to remove these skip statements once a new tfp release is out.
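The skip statements themselves aren't shown in this thread. Here is a minimal sketch of what one could look like with a pytest marker. The test name is hypothetical. Putting the issue URL in `reason` makes the markers easy to grep for when the next tfp release is out:

```python
import pytest


# Hypothetical tfp-dependent test, skipped until a fixed tfp release is available.
# The issue URL in `reason` makes every such marker easy to find and remove later.
@pytest.mark.skip(
    reason="waiting for next tfp release, see https://github.com/pyro-ppl/numpyro/issues/1814"
)
def test_tfp_dependent_feature():
    ...
```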

juanitorduz commented 2 weeks ago

There is a new failing test, probably caused by the new JAX and/or NumPy releases:

__________________ test_discrete_site_without_infer_enumerate __________________

    def test_discrete_site_without_infer_enumerate():
        def model():
            numpyro.sample("x", dist.Bernoulli(0.5))

        mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
        with pytest.warns(FutureWarning, match="enumerated sites"):
>           mcmc.run(random.PRNGKey(0))
E           FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.

test/infer/test_mcmc.py:1104: FutureWarning

I added a different match pattern in https://github.com/pyro-ppl/numpyro/pull/1817/commits/aa30c69b57184b6ce16d0ce19d2892186fd5e249, but I think it is essential to address these warnings, especially because we are also getting:

DeprecationWarning: numpy.core.numeric is deprecated and has been renamed to numpy._core.numeric. The numpy._core namespace contains private NumPy internals and its use is discouraged, as NumPy internals can change without warning in any release. In practice, most real-world usage of numpy.core is to access functionality in the public NumPy API. If that is the case, use the public NumPy API. If not, you are using NumPy internals. If you would still like to access an internal attribute, use numpy._core.numeric.normalize_axis_tuple.
    from numpy.core.numeric import normalize_axis_tuple

test/infer/test_mcmc.py::test_discrete_site_without_infer_enumerate
  /Users/juanitorduz/Documents/envs/numpyro-env/lib/python3.12/site-packages/jax/_src/linear_util.py:192: DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is deprecated. Please use 'x', 'min', and 'max' respectively instead.
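Both deprecation warnings name their public replacements. The thread doesn't say which package makes these calls, so the following sketch assumes they live in code you control. It shows two changes:

- A version-tolerant import of `normalize_axis_tuple`. I believe NumPy 2.0 exposes it publicly as `numpy.lib.array_utils.normalize_axis_tuple`.
- A `jnp.clip` call that passes the bounds positionally instead of via `a_min`/`a_max`.

```python
import jax.numpy as jnp

# NumPy >= 2.0 exposes normalize_axis_tuple publicly in numpy.lib.array_utils;
# fall back to the numpy.core location used by NumPy 1.x.
try:
    from numpy.lib.array_utils import normalize_axis_tuple
except ImportError:
    from numpy.core.numeric import normalize_axis_tuple

# Resolve a negative axis against a 3-d array: -1 becomes 2.
axes = normalize_axis_tuple(-1, 3)

# Passing the bounds positionally avoids the deprecated a_min/a_max keywords
# and, as far as I know, works on both older and newer JAX releases.
x = jnp.array([-1.0, 0.5, 2.0])
clipped = jnp.clip(x, 0.0, 1.0)

print(axes, clipped.tolist())  # (2,) [0.0, 0.5, 1.0]
```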
juanitorduz commented 2 weeks ago

It seems

FutureWarning: unhashable type: <class 'jax._src.interpreters.batching.BatchTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.

is all over the place now 😓

fehiepsi commented 2 weeks ago

Thanks @juanitorduz. I'm looking at them.

juanitorduz commented 2 weeks ago

OK! You can either push to this branch or create a new one if needed.

fehiepsi commented 2 weeks ago

It turns out that funsor has some checks requiring tracers to be Hashable. I don't think the new behavior will cause issues: it is fine for arrays to be either hashable or unhashable. So I think we can simply filter out these warnings:

import warnings
warnings.filterwarnings("ignore", message=".*Attempting to hash a tracer.*", category=FutureWarning)


Plus, for pytest, add `"ignore:.*Attempting to hash a tracer:FutureWarning",` to the `filterwarnings` list in `pyproject.toml`.
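The leading `.*` matters because of how `warnings.filterwarnings` matches. It compiles `message` as a case-insensitive regex and matches it at the start of the warning text. The JAX message begins with "unhashable type", not "Attempting". The check below runs against the message quoted above and uses only the standard library; `is_suppressed` is a hypothetical helper for illustration:

```python
import warnings

# Warning text copied from the JAX FutureWarning quoted earlier in this thread.
MSG = (
    "unhashable type: <class 'jax._src.interpreters.batching.BatchTracer'>. "
    "Attempting to hash a tracer will lead to an error in a future JAX release."
)


def is_suppressed(pattern: str) -> bool:
    """Hypothetical helper: does an 'ignore' filter with `pattern` silence MSG?"""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record anything not explicitly ignored
        # filterwarnings inserts at the front of the filter list, so it is checked first.
        warnings.filterwarnings("ignore", message=pattern, category=FutureWarning)
        warnings.warn(MSG, FutureWarning)
    return not caught


print(is_suppressed(".*Attempting to hash a tracer.*"))  # True
print(is_suppressed("Attempting to hash a tracer"))  # False: matching starts at the beginning
```

As far as I know, pytest passes the message part of each `filterwarnings` entry to the same function, so the `.*` prefix is needed there too.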
juanitorduz commented 2 weeks ago

OK! We are making progress 😅! Now we have:

FAILED test/contrib/test_control_flow.py::test_scan - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[10])', 'ShapedArray(float32[10])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])'], []).
FAILED test/contrib/test_control_flow.py::test_scan_svi - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[3,5])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[5])'], []).

and

TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[10])'], []).
test/test_examples.py::test_cpu[holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2] Running:
python examples/holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2

I am also seeing:

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

🤔

juanitorduz commented 2 weeks ago

It seems like a type problem... Could this be another recent NumPy or JAX change?

fehiepsi commented 2 weeks ago

Hi @juanitorduz, I reported the upstream error in https://github.com/google/jax/issues/22045. As a fix, could you help me change every `device_put(foo)` in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/control_flow/scan.py to

tree_map(device_put, foo)
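Here is a minimal sketch of the suggested pattern, assuming `foo` is a pytree. The tuple below is hypothetical. The sketch does not reproduce the upstream bug, which is tracked in the linked JAX issue. It only shows `device_put` being mapped over the leaves with `jax.tree_util.tree_map` instead of being called on the whole tree:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

# Hypothetical pytree standing in for a value passed to device_put in scan.py:
# a float array leaf plus a plain Python int leaf.
foo = (jnp.zeros(3), 0)

# Suggested pattern: map device_put over the leaves instead of calling it on the whole tree.
placed = tree_map(jax.device_put, foo)

print(type(placed).__name__, placed[0].shape, int(placed[1]))  # tuple (3,) 0
```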
juanitorduz commented 2 weeks ago

In https://github.com/pyro-ppl/numpyro/pull/1817/commits/19f82322060fd1b6ce31be1797f98b254610c2bf I still saw other tests failing:

FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long

Hence: https://github.com/pyro-ppl/numpyro/pull/1817/commits/038dd84ad2ad51ba2d0ae8f8add1f3893c16db63

juanitorduz commented 2 weeks ago

OK! It's finally 🟢!

Shall we revert these changes once the JAX fix is released? And similarly for the tfp workarounds once the next tfp release is out?
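On the reverting question, one option (not what this PR does) is to replace the plain skips with `pytest.mark.xfail(strict=True)`. With `strict=True`, an unexpected pass is reported as a failure, so the suite itself flags the markers once a fixed release is installed. The sketch below uses a hypothetical test whose body is a placeholder for the real tfp call. The `raises=OverflowError` argument assumes the failure is the OverflowError seen above:

```python
import pytest


# Hypothetical test; the body is a placeholder standing in for the real tfp call.
# raises=OverflowError (an assumption based on the failures above): any other
# exception is still reported as a real failure.
# strict=True: an unexpected pass (XPASS) fails the suite, flagging the marker
# for removal once a fixed tfp release lands.
@pytest.mark.xfail(
    raises=OverflowError,
    strict=True,
    reason="waiting for next tfp release, see https://github.com/pyro-ppl/numpyro/issues/1814",
)
def test_waiting_for_tfp_release():
    raise OverflowError("Python int too large to convert to C long")
```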