qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0

Remove bound on JAX/Diffrax versions #266

Closed: DanPuzzuoli closed this pull request 1 year ago

DanPuzzuoli commented 1 year ago

Summary

Closes #190

Since it is unclear when JAX will fix the bug that causes an error in the perturbation module, I revisited the issue and found a simple workaround that bypasses the bug. Realistically I should have tried this sooner, but I didn't expect the issue to stay open this long.

The workaround fixes the perturbation module, but many errors/warnings now come up across different tests. This is to be expected: there have been over 10 minor releases of JAX since we put a bound on the version. I'll need to work through each folder and figure out how to fix them.

Details and comments

I'm working through each submodule to fix the errors/warnings in the tests (a strikethrough indicates that running the module's tests raises no warnings or errors; a module with neither a comment nor a strikethrough hasn't been checked yet).

DanPuzzuoli commented 1 year ago

It seems the complex -> real casting warnings only ever occur in differentiated functions, so this is probably some odd behaviour of JAX transformations (i.e. these warnings aren't raised when the function is not differentiated). I need to find a minimal example to understand what's going on.

Edit: It turns out this warning is raised when computing the gradient of a function that involves a tensordot of a real array with a complex array. I'll see what the JAX developers say about this: it seems the warning should either also be raised when the gradient is not computed (for consistency), or not be raised during gradient computation at all.

Asked here: https://github.com/google/jax/discussions/18133
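The behaviour described above can be probed with a small script along these lines. This is a minimal sketch, not the example from the linked discussion: the operator stack ops, the coefficients coeffs, and the helpers f and run_with_warnings are hypothetical, and whether a complex -> real casting warning (or any warning) actually appears depends on the installed JAX version. The script contracts a real coefficient array with a complex array via jnp.tensordot, reduces the result to a real scalar so that jax.grad applies, and records the warnings raised when evaluating f versus jax.grad(f).

```python
import warnings

import jax
import jax.numpy as jnp

# Hypothetical stack of two complex 2x2 "operators" (not taken from the PR).
ops = jnp.array(
    [
        [[1.0 + 1.0j, 0.0], [0.0, 2.0 - 1.0j]],
        [[0.5j, 1.0], [1.0, 4.0]],
    ]
)


def f(coeffs):
    # Real coefficients contracted against complex operators -> complex 2x2 matrix.
    mat = jnp.tensordot(coeffs, ops, axes=1)
    # Reduce to a real scalar so jax.grad can be used without holomorphic=True.
    return jnp.real(jnp.trace(mat))


def run_with_warnings(fn, *args):
    # Hypothetical helper: call fn and collect the messages of any Python warnings.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = fn(*args)
    return result, [str(w.message) for w in caught]


coeffs = jnp.array([0.3, 0.7])  # real input (float32 unless x64 is enabled)

value, value_warnings = run_with_warnings(f, coeffs)
grad_value, grad_warnings = run_with_warnings(jax.grad(f), coeffs)

print("f:", value, "warnings:", value_warnings)
print("grad f:", grad_value, "warnings:", grad_warnings)
```

Since f is linear in coeffs, its gradient is the real part of the trace of each operator (here [3., 4.]), which gives a quick correctness check alongside the comparison of recorded warnings.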

DanPuzzuoli commented 1 year ago

As it currently stands: