bwohlberg opened 3 weeks ago
JAX release 0.4.29 appears to have again broken a component of scico.jax (full log):
```
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
rootdir: /home/runner/work/scico/scico
configfile: pytest.ini
testpaths: scico/test, docs
plugins: split-0.8.2
collected 3329 items / 3 skipped

scico/test/flax/test_apply.py ....... [ 0%]
scico/test/flax/test_checkpoints.py .... [ 0%]
scico/test/flax/test_clu.py ..... [ 0%]
scico/test/flax/test_examples_flax.py ss..ssssF......................... [ 1%]
..... [ 1%]
scico/test/flax/test_flax.py .......................... [ 2%]
[...]
=================================== FAILURES ===================================
__________________________ test_blur_data_generation ___________________________

>   ???

_mt19937.pyx:180:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:766: in __index__
    raise self.aval._index(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = ShapedArray(int64[])
arg = Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
  val = Array([0], dtype=int64)
  batch_dim = 0

    def error(self, arg):
>     raise TracerIntegerConversionError(arg)
E     jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
E     This BatchTracer with object id 140001038232176 was created on line:
E       /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E     See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:1508: TracerIntegerConversionError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "_mt19937.pyx", line 180, in numpy.random._mt19937.MT19937._legacy_seeding
  File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 766, in __index__
    raise self.aval._index(self)
  File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 1508, in error
    raise TracerIntegerConversionError(arg)
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
  /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

During handling of the above exception, another exception occurred:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

    def test_blur_data_generation():
        N = 32
        nimg = 8
        n = 3  # convolution kernel size
        blur_kernel = np.ones((n, n)) / (n * n)

        def random_img_gen(seed, size, ndata):
            np.random.seed(seed)
            return np.random.randn(ndata, size, size, 1)

>       img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen)

scico/test/flax/test_examples_flax.py:157:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scico/flax/examples/data_generation.py:318: in generate_blur_data
    img = distributed_data_generation(imgfunc, size, nimg, False)
scico/flax/examples/data_generation.py:382: in distributed_data_generation
    imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc)
scico/test/flax/test_examples_flax.py:154: in random_img_gen
    np.random.seed(seed)
numpy/random/mtrand.pyx:4806: in numpy.random.mtrand.seed
    ???
numpy/random/mtrand.pyx:250: in numpy.random.mtrand.RandomState.seed
    ???
_mt19937.pyx:168: in numpy.random._mt19937.MT19937._legacy_seeding
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   ???
E   jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
E   This BatchTracer with object id 140001038232176 was created on line:
E     /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E   See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

_mt19937.pyx:185: TracerArrayConversionError
=========================== short test summary info ============================
FAILED scico/test/flax/test_examples_flax.py::test_blur_data_generation - jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
  /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
====== 1 failed, 3280 passed, 24 skipped, 27 xfailed in 294.89s (0:04:54) ======
Error: Process completed with exit code 1.
```
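The traceback shows `np.random.seed(seed)` being called inside `jax.vmap`, where `seed` is a `BatchTracer` rather than a concrete integer, so NumPy's legacy seeding cannot convert it. A minimal sketch of this pattern outside scico is below, assuming only `jax` and `numpy` are installed. `np_img_gen` and `jax_img_gen` are hypothetical names, not scico functions, and the JAX-native variant (deriving a `jax.random.PRNGKey` from each traced seed) is one illustrative alternative, not a proposed scico fix. Because the exact exception class may differ between JAX/NumPy versions, the sketch only checks that the NumPy path fails.

```python
import jax
import jax.numpy as jnp
import numpy as np


def np_img_gen(seed, size, ndata):
    # Mirrors random_img_gen from the failing test: NumPy's global seeding
    # needs a concrete integer, but under jax.vmap `seed` is a tracer.
    np.random.seed(seed)
    return np.random.randn(ndata, size, size, 1)


def jax_img_gen(seed, size, ndata):
    # Hypothetical JAX-native variant: build a PRNG key from the (possibly
    # traced) seed and draw the images with jax.random instead of NumPy.
    key = jax.random.PRNGKey(seed)
    return jax.random.normal(key, (ndata, size, size, 1))


size, ndata = 4, 2
seeds = jnp.arange(3)

# Only `seed` is mapped; size and ndata are closed over as plain Python ints.
try:
    jax.vmap(lambda s: np_img_gen(s, size, ndata))(seeds)
    numpy_failed = False
except Exception as err:  # exact class (e.g. TracerArrayConversionError) may vary
    numpy_failed = True
    print(type(err).__name__)

imgs = jax.vmap(lambda s: jax_img_gen(s, size, ndata))(seeds)
print(imgs.shape)  # (3, 2, 4, 4, 1)
```

Closing over `size` and `ndata` rather than passing them with `in_axes=None` keeps them as Python ints, so they can be used directly as array shapes inside the mapped function.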
`scico/test/flax/test_examples_flax.py` tests are also failing on `jax` 0.4.28 (nominally supported according to the current `requirements.txt`) on a GPU device.