Closed juanitorduz closed 2 weeks ago
There is a new test failing, probably because of the new JAX and/or NumPy releases:
__________________ test_discrete_site_without_infer_enumerate __________________
def test_discrete_site_without_infer_enumerate():
def model():
numpyro.sample("x", dist.Bernoulli(0.5))
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
with pytest.warns(FutureWarning, match="enumerated sites"):
> mcmc.run(random.PRNGKey(0))
E FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
test/infer/test_mcmc.py:1104: FutureWarning
I added a different match group in https://github.com/pyro-ppl/numpyro/pull/1817/commits/aa30c69b57184b6ce16d0ce19d2892186fd5e249 but I think it is essential to address these warnings. Especially because we are also getting
DeprecationWarning: numpy.core.numeric is deprecated and has been renamed to numpy._core.numeric. The numpy._core namespace contains private NumPy internals and its use is discouraged, as NumPy internals can change without warning in any release. In practice, most real-world usage of numpy.core is to access functionality in the public NumPy API. If that is the case, use the public NumPy API. If not, you are using NumPy internals. If you would still like to access an internal attribute, use numpy._core.numeric.normalize_axis_tuple.
from numpy.core.numeric import normalize_axis_tuple
test/infer/test_mcmc.py::test_discrete_site_without_infer_enumerate
/Users/juanitorduz/Documents/envs/numpyro-env/lib/python3.12/site-packages/jax/_src/linear_util.py:192: DeprecationWarning: Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is deprecated. Please use 'x', 'min', and 'max' respectively instead.
It seems
FutureWarning: unhashable type: <class 'jax._src.interpreters.batching.BatchTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
is all over the place now 😓
Thanks @juanitorduz. I'm looking at them.
ok! You can either push to this branch or create a new one if needed
It turns out that funsor has some checks for tracers being Hashable. I don't think the new behavior will cause issues: it is fine for arrays to be either hashable or unhashable. So I think we can simply filter out these warnings in `__init__`, at the top of the file:
import warnings
warnings.filterwarnings("ignore", message=".*Attempting to hash a tracer.*", category=FutureWarning)
+ in `pyproject.toml`: `"ignore:.*Attempting to hash a tracer:FutureWarning",` for pytest
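For anyone checking the filter locally: the `message` argument of `warnings.filterwarnings` is a regular expression that has to match from the start of the warning text, so the leading `.*` matters because the message begins with "unhashable type". Here is a minimal sketch using only the standard library. The helper names `emit_tracer_warning` and `count_warnings` are made up for illustration, and the warning text is a shortened stand-in for the real JAX message.

```python
import warnings


def emit_tracer_warning():
    # Stand-in for the JAX warning; the real one names the tracer class.
    warnings.warn(
        "unhashable type: <class 'Tracer'>. Attempting to hash a tracer "
        "will lead to an error in a future JAX release.",
        FutureWarning,
    )


def count_warnings(pattern):
    # Record every warning that is NOT silenced by the given filter pattern.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Inserted at the front of the filter list, so it is checked first.
        warnings.filterwarnings("ignore", message=pattern, category=FutureWarning)
        emit_tracer_warning()
    return len(caught)


# With the wildcard prefix the pattern matches from the start of the message.
with_wildcard = count_warnings(".*Attempting to hash a tracer.*")
# A single "." only consumes one character, so this pattern does not match.
without_wildcard = count_warnings(".Attempting to hash a tracer.")

print(with_wildcard, without_wildcard)
```

The first count shows the warning being silenced, while the second shows that a pattern without `.*` lets it through.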
ok! We are making progress 😅 ! Now we have
FAILED test/contrib/test_control_flow.py::test_scan - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[10])', 'ShapedArray(float32[10])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])'], []).
FAILED test/contrib/test_control_flow.py::test_scan_svi - TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[3,5])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[5])'], []).
and
TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', ['ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'DIFFERENT ShapedArray(int32[], weak_type=True) vs. ShapedArray(float0[])', 'ShapedArray(float32[])', 'ShapedArray(float32[])', 'ShapedArray(float32[10])'], []).
test/test_examples.py::test_cpu[holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2] Running:
python examples/holt_winters.py --T 4 --num-samples 10 --num-warmup 10 --num-chains 2
I am also seeing
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
🤔
It seems like a type problem ... could this be again a numpy or jax recent change?
Hi @juanitorduz, I raised the upstream error in https://github.com/google/jax/issues/22045. For a fix, could you help me change every `device_put(foo)` in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/control_flow/scan.py to `tree_map(device_put, foo)`?
In https://github.com/pyro-ppl/numpyro/pull/1817/commits/19f82322060fd1b6ce31be1797f98b254610c2bf I still saw other tests failing:
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[1-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-MetropolisAdjustedLangevinAlgorithm-kwargs0] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-SliceSampler-kwargs2] - OverflowError: Python int too large to convert to C long
FAILED test/contrib/test_tfp.py::test_unnormalized_normal_chain[2-UncalibratedLangevin-kwargs3] - OverflowError: Python int too large to convert to C long
Hence: https://github.com/pyro-ppl/numpyro/pull/1817/commits/038dd84ad2ad51ba2d0ae8f8add1f3893c16db63
Ok! Finally it is 🟢!
Shall we revert these changes once the JAX issue is fixed and released? Similarly for tfp next release?
Partially addresses https://github.com/pyro-ppl/numpyro/issues/1814. We must keep in mind to remove these `skip` statements once we see a new release.