I needed to update my JAX installation and found that JAX 0.4.35 no longer has `jax._src.numpy.util._wraps`, which was renamed in this PR to `jax._src.numpy.util.implements`. Changing the import and the decorator in `_jax_idct.py` accordingly resolved the `ImportError` and seems to work fine for my own use case.