The recent JAX 0.4.4 update has broken testing for the perturbation module. The discussion in the JAX repo is here; it closely resembles an issue that was introduced and discussed a year ago here (the same minimal example reproduces both).
The JAX changelog indicates that running the following lines before importing JAX disables the major change in this release (the merge of `jax.jit` with `jax.experimental.pjit.pjit`), which appears to be causing the problem:

```python
import os
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
```
This PR adds these lines to the `__init__.py` file in the testing folder so that the CI tests run again.
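A minimal sketch of what the testing package's `__init__.py` might look like, assuming the variable only needs to be set once per test session. The `sys.modules` warning and the use of `os.environ.setdefault` (so a value exported from the shell still takes precedence) are illustrative additions, not part of this PR.

```python
"""Test package initialisation (sketch)."""
import os
import sys
import warnings

# The JAX changelog says to set this variable before importing JAX, so warn
# if some earlier import has already pulled jax in.
if "jax" in sys.modules:
    warnings.warn(
        "jax was imported before JAX_JIT_PJIT_API_MERGE was set; "
        "the setting may have no effect."
    )

# Assumption: setdefault lets an explicit value from the environment win,
# while defaulting to "0" (merge disabled) for CI runs.
os.environ.setdefault("JAX_JIT_PJIT_API_MERGE", "0")
```

Using `setdefault` rather than a plain assignment is a judgment call: it keeps CI behaviour unchanged while letting a developer re-enable the merge locally to check whether the upstream issue has been fixed.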