lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Test failure in `scico.jax` with latest `jax` version 0.4.29 #535

Open bwohlberg opened 3 weeks ago

bwohlberg commented 3 weeks ago

JAX release 0.4.29 appears to have broken a component of scico.jax again (full log):

============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
rootdir: /home/runner/work/scico/scico
configfile: pytest.ini
testpaths: scico/test, docs
plugins: split-0.8.2
collected 3329 items / 3 skipped
scico/test/flax/test_apply.py .......                                    [  0%]
scico/test/flax/test_checkpoints.py ....                                 [  0%]
scico/test/flax/test_clu.py .....                                        [  0%]
scico/test/flax/test_examples_flax.py ss..ssssF......................... [  1%]
.....                                                                    [  1%]
scico/test/flax/test_flax.py ..........................                  [  2%]
[...]
=================================== FAILURES ===================================
__________________________ test_blur_data_generation ___________________________
>   ???
_mt19937.pyx:180: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:766: in __index__
    raise self.aval._index(self)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
self = ShapedArray(int64[])
arg = Traced<ShapedArray(int64[])>with<BatchTrace(level=1/0)> with
  val = Array([0], dtype=int64)
  batch_dim = 0
    def error(self, arg):
>     raise TracerIntegerConversionError(arg)
E     jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
E     This BatchTracer with object id 140001038232176 was created on line:
E       /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E     See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py:1508: TracerIntegerConversionError
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "_mt19937.pyx", line 180, in numpy.random._mt19937.MT19937._legacy_seeding
  File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 766, in __index__
    raise self.aval._index(self)
  File "/opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/jax/_src/core.py", line 1508, in error
    raise TracerIntegerConversionError(arg)
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
  /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
During handling of the above exception, another exception occurred:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
    def test_blur_data_generation():
        N = 32
        nimg = 8
        n = 3  # convolution kernel size
        blur_kernel = np.ones((n, n)) / (n * n)

        def random_img_gen(seed, size, ndata):
            np.random.seed(seed)
            return np.random.randn(ndata, size, size, 1)

>       img, blurn = generate_blur_data(nimg, N, blur_kernel, noise_sigma=0.01, imgfunc=random_img_gen)
scico/test/flax/test_examples_flax.py:157: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
scico/flax/examples/data_generation.py:318: in generate_blur_data
    img = distributed_data_generation(imgfunc, size, nimg, False)
scico/flax/examples/data_generation.py:382: in distributed_data_generation
    imgs = jax.vmap(imgenf, (0, None, None))(idx, size, ndata_per_proc)
scico/test/flax/test_examples_flax.py:154: in random_img_gen
    np.random.seed(seed)
numpy/random/mtrand.pyx:4806: in numpy.random.mtrand.seed
    ???
numpy/random/mtrand.pyx:250: in numpy.random.mtrand.RandomState.seed
    ???
_mt19937.pyx:168: in numpy.random._mt19937.MT19937._legacy_seeding
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
>   ???
E   jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
E   This BatchTracer with object id 140001038232176 was created on line:
E     /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
E   See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
_mt19937.pyx:185: TracerArrayConversionError
=========================== short test summary info ============================
FAILED scico/test/flax/test_examples_flax.py::test_blur_data_generation - jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int64[].
This BatchTracer with object id 140001038232176 was created on line:
  /home/runner/work/scico/scico/scico/test/flax/test_examples_flax.py:154 (random_img_gen)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
====== 1 failed, 3280 passed, 24 skipped, 27 xfailed in 294.89s (0:04:54) ======
Error: Process completed with exit code 1.
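
Reading the traceback above, the failure appears to come from `np.random.seed(seed)` being called inside `jax.vmap`: within the mapped function, `seed` is a `BatchTracer` rather than a concrete integer, so NumPy cannot convert it. One possible workaround, sketched below, is to make the test's image generator traceable by using `jax.random` instead of NumPy's global RNG. This is only an illustration under that assumption, not the fix adopted in scico; the `vmap` call here is a simplified stand-in for the one in `distributed_data_generation`, with the shape arguments closed over so that they stay static Python integers.

```python
import jax
import jax.numpy as jnp


def random_img_gen(seed, size, ndata):
    # Hypothetical traceable variant of the test helper: derive a PRNG key
    # from the (possibly traced) seed instead of seeding NumPy's global RNG.
    key = jax.random.PRNGKey(seed)
    return jax.random.normal(key, (ndata, size, size, 1))


size, ndata = 4, 3
idx = jnp.arange(2)  # one seed per mapped call

# Map over the seed only; size and ndata are captured by the closure.
imgs = jax.vmap(lambda s: random_img_gen(s, size, ndata))(idx)
print(imgs.shape)  # (2, 3, 4, 4, 1)
```

If the image function has to keep using NumPy, the alternative would be to avoid tracing it at all, for example by calling it in a plain Python loop over concrete seeds.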
bwohlberg commented 3 weeks ago

The tests in scico/test/flax/test_examples_flax.py are also failing with jax 0.4.28 (nominally supported according to the current requirements.txt) when run on a GPU device.