FUNSOR_BACKEND=jax pytest test_distribution.py -k test_categorical_event_dim_conversion
The test fails with the new jax version 0.3.10:
E AssertionError: Expected:
E [-16.55620432 -16.0037727 -14.16045645 -16.39053217 nan]
E Actual:
E [-16.55620432 -16.0037727 -14.16045645 -16.39053217 nan]
E assert nan < 1e-06
E + where nan = <built-in method max of numpy.ndarray object at 0x7f49682d61b0>()
E + where <built-in method max of numpy.ndarray object at 0x7f49682d61b0> = (array([nan]) / (1e-06 + array([nan]))).max
E + where array([nan]) = abs(array([nan]))
../funsor/testing.py:204: AssertionError
============================ warnings summary =============================
../../../anaconda3/envs/pyro-dev/lib/python3.8/site-packages/flatbuffers/compat.py:19
/home/ordabayev/anaconda3/envs/pyro-dev/lib/python3.8/site-packages/flatbuffers/compat.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
import imp
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================= short test summary info =========================
FAILED test_distribution.py::test_categorical_event_dim_conversion[(4, 7)-(5,)]
FAILED test_distribution.py::test_categorical_event_dim_conversion[(4, 1, 7)-(5,)]
========= 2 failed, 16 passed, 558 deselected, 1 warning in 5.98s =========
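Expected and actual agree everywhere except that both hold NaN in the last position, and the assertion that fails is `nan < 1e-06`. A NaN in a relative-error expression propagates to the final value, and any ordered comparison with NaN is False, so the check fails even though the two arrays match. The sketch below reproduces that arithmetic in plain NumPy using the values from the log. It is not funsor's actual `assert_close` code; `relative_error` is a hypothetical helper modeled on the expression in the traceback.

```python
import numpy as np

# Values copied from the log above: both arrays end with NaN.
expected = np.array([-16.55620432, -16.0037727, -14.16045645, -16.39053217, np.nan])
actual = np.array([-16.55620432, -16.0037727, -14.16045645, -16.39053217, np.nan])


def relative_error(actual, expected, eps=1e-6):
    # Hypothetical helper: |actual - expected| / (eps + |expected|), max over entries.
    diff = np.abs(actual - expected)
    return (diff / (eps + np.abs(expected))).max()


err = relative_error(actual, expected)
print(err)         # nan: nan - nan is nan, and ndarray.max propagates NaN
print(err < 1e-6)  # False: every ordered comparison with NaN is False

# A NaN-aware comparison that treats NaNs in matching positions as equal.
print(np.allclose(actual, expected, rtol=1e-6, atol=0.0, equal_nan=True))  # True
```

This only shows why a NaN in both arrays makes this kind of check fail. It does not say where the NaN comes from under jax 0.3.10.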