DanPuzzuoli closed this 1 year ago
@wshanks I've implemented your suggestion so that the warning is not triggered if the JAX version is 0.4.4, 0.4.5, or 0.4.6 and the user has already set the JAX_JIT_PJIT_API_MERGE environment variable via os.environ. I've reverted the docs to use JAX 0.4.6, since the warning will no longer be triggered.
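A minimal sketch of the exemption described above. It assumes, hypothetically, that the warning otherwise fires for every JAX version newer than 0.4.3, and that "already set" means the variable is set to "0". The names needs_version_warning, maybe_warn, parse_version and the constants are illustrative only, not the project's actual code.

```python
import os
import warnings
from typing import Mapping, Optional, Tuple

# Hypothetical constants: the variable name and version list come from the PR
# discussion; the lower bound (0.4.3 never warns) is an assumption.
PJIT_MERGE_VAR = "JAX_JIT_PJIT_API_MERGE"
EXEMPTABLE_VERSIONS = {(0, 4, 4), (0, 4, 5), (0, 4, 6)}
LAST_UNAFFECTED = (0, 4, 3)


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn a plain release string such as '0.4.6' into (0, 4, 6)."""
    return tuple(int(part) for part in version.split(".")[:3])


def needs_version_warning(version: str, env: Mapping[str, str]) -> bool:
    """Return True if the (hypothetical) version-bounds warning should fire."""
    parsed = parse_version(version)
    if parsed <= LAST_UNAFFECTED:
        return False
    # Exemption from review: 0.4.4-0.4.6 are accepted once the user has
    # disabled the jit/pjit API merge through the environment variable.
    if parsed in EXEMPTABLE_VERSIONS and env.get(PJIT_MERGE_VAR) == "0":
        return False
    return True


def maybe_warn(version: str, env: Optional[Mapping[str, str]] = None) -> None:
    """Emit a UserWarning if needed; defaults to the real process environment."""
    env = os.environ if env is None else env
    if needs_version_warning(version, env):
        warnings.warn(
            f"JAX {version} is outside the tested version bounds.", UserWarning
        )
```

At import time this would be called as maybe_warn(jax.__version__). Passing the environment in as a mapping keeps the check easy to exercise without mutating os.environ.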
So I still ended up needing to set os.environ["JAX_JIT_PJIT_API_MERGE"] = "0" in docs/conf.py to avoid the warning in the docs build, but luckily that worked. I've put that in with a comment and added a description of this PR to issue #190, where I'm keeping track of all the small changes we've had to make because of these JAX issues.
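A sketch of what the docs/conf.py change might look like. Our understanding is that JAX reads this variable when its configuration is initialised at import, so the assignment presumably has to come before anything in the docs build imports jax; the exact placement and comment wording here are assumptions.

```python
# docs/conf.py (excerpt, sketch)
import os

# Disable the jit/pjit API merge for JAX 0.4.4-0.4.6 so the version-bounds
# warning is not raised during the docs build. This must run before jax is
# imported, because JAX picks up the variable when its config is set up.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
```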
I was also able to drop DeviceArray from JAX_TYPES entirely: Array and Tracer turn out to be all that's necessary these days.
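A sketch of a JAX_TYPES tuple containing only Array and Tracer, used for a simple isinstance-based backend check. The helpers _load_jax_types and infer_backend are hypothetical and are not the dispatch module's real API; the fallback to an empty tuple when JAX is missing is also an illustrative choice.

```python
from typing import Any, Optional, Tuple


def _load_jax_types() -> Tuple[type, ...]:
    """Collect the JAX array types, or an empty tuple if JAX is unavailable."""
    try:
        from jax import Array  # concrete JAX arrays
        from jax.core import Tracer  # abstract values seen under jit/grad/vmap
    except ImportError:
        return ()
    return (Array, Tracer)


# Only Array and Tracer are registered; DeviceArray is no longer needed.
JAX_TYPES = _load_jax_types()


def infer_backend(obj: Any) -> Optional[str]:
    """Return 'jax' for JAX arrays and tracers, otherwise None (hypothetical)."""
    if JAX_TYPES and isinstance(obj, JAX_TYPES):
        return "jax"
    return None
```

Inside jax.jit, function arguments arrive as Tracer instances rather than concrete arrays, which is why Tracer is listed alongside Array.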
Summary
Closes #231
This PR updates the JAX type registration in dispatch and adds a warning at JAX import time about the version bounds for JAX.
Details and comments
The docs build has also been set to run with JAX version 0.4.3 to avoid this warning.