qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0

Add warning for JAX versions on import of Dynamics #232

Closed: DanPuzzuoli closed this pull request 1 year ago.

DanPuzzuoli commented 1 year ago

Summary

Closes #231

This PR updates the JAX type registration in dispatch and adds a warning, raised when JAX is imported, about the supported JAX version bounds.

Details and comments

The docs build has also been set to run with JAX version 0.4.3 to avoid this warning.

DanPuzzuoli commented 1 year ago

@wshanks I've implemented your suggestion so that the warning does not trigger if the JAX version is 0.4.4, 0.4.5, or 0.4.6 and the user has already set the JAX_JIT_PJIT_API_MERGE environment variable. I've reverted the docs to use JAX 0.4.6, since the warning should no longer be triggered.

DanPuzzuoli commented 1 year ago

So I still ended up needing to set os.environ["JAX_JIT_PJIT_API_MERGE"] = "0" in docs/conf.py to avoid the warning in the docs build, but luckily that worked. I've put that in with a comment, and have added a description of this PR to issue #190 (where I'm keeping track of all the little changes that we've had to make because of these JAX issues).
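For reference, a sketch of roughly what the docs/conf.py change amounts to; it is not the exact file contents. The assumption is that the variable must be set before JAX is first imported, because JAX reads its config flag defaults from the environment at import time.

```python
import os

# Opt out of the jit/pjit API merge (JAX 0.4.4 to 0.4.6) so that the Dynamics
# version warning is not raised during the docs build. Assumption: this must
# run before anything in the build imports jax.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"
```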

I was also able to stop adding DeviceArray to JAX_TYPES entirely: Array and Tracer turn out to be all that's necessary these days.