Needed to add an additional Array wrapping at one spot in rotating_wave_approximation, where a raw JAX array was controlling a binary operation.
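This is a minimal, hypothetical sketch of why extra wrapping can matter. For a * b, Python calls type(a).__mul__ first. It falls back to type(b).__rmul__ only if that call returns NotImplemented, assuming b's type is not a subclass of a's. RawArray and Wrapper are illustrative stand-ins that roughly mimic a raw backend array and the Array wrapper. They are not the library's actual types.

```python
class RawArray:
    """Hypothetical stand-in for a raw backend array whose __mul__ accepts anything."""

    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        # Unwraps the other operand and returns a RawArray, dropping any wrapper type.
        other_value = other.value if hasattr(other, "value") else other
        return RawArray(self.value * other_value)


class Wrapper:
    """Hypothetical stand-in for a wrapper type such as Array."""

    def __init__(self, value):
        self.value = value.value if isinstance(value, RawArray) else value

    def __mul__(self, other):
        other_value = other.value if hasattr(other, "value") else other
        return Wrapper(self.value * other_value)

    __rmul__ = __mul__  # multiplication here is commutative


raw = RawArray(2)
wrapped = Wrapper(3)

# The left operand's __mul__ runs first, so the result loses the wrapper type.
print(type(raw * wrapped).__name__)  # RawArray
# Wrapping the raw operand first puts Wrapper in control of the operation.
print(type(Wrapper(raw) * wrapped).__name__)  # Wrapper
```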
Also, in final_state_converter, JAX arrays were being passed into quantum info types (e.g. DensityMatrix), which raised an error. I've changed this function to always convert the array to a numpy.array before wrapping it in cls. Since cls will always be a quantum info type, it can only accept NumPy arrays in any case. The docstring for this function has been updated to explain this.
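Below is a minimal sketch of the conversion described above. It assumes cls only accepts NumPy arrays. The names convert_final_state, StrictState, and BackendArray are hypothetical stand-ins for final_state_converter, a quantum info type such as DensityMatrix, and a JAX array, respectively. This is not the library's code.

```python
import numpy as np


def convert_final_state(y, cls):
    """Hypothetical sketch: coerce y to a NumPy array, then wrap it in cls."""
    # np.asarray accepts ndarrays, nested lists, and any object implementing
    # __array__ (concrete JAX arrays do), so cls always receives an np.ndarray.
    return cls(np.asarray(y))


class StrictState:
    """Stand-in for a quantum info type that only accepts np.ndarray input."""

    def __init__(self, data):
        if not isinstance(data, np.ndarray):
            raise TypeError(f"expected np.ndarray, got {type(data).__name__}")
        self.data = data


class BackendArray:
    """Stand-in for a non-NumPy array type (e.g. a JAX array) exposing __array__."""

    def __init__(self, values):
        self._values = values

    def __array__(self, dtype=None, copy=None):
        return np.array(self._values, dtype=dtype)


raw = BackendArray([[1.0, 0.0], [0.0, 0.0]])

try:
    StrictState(raw)  # passing the backend array directly fails
except TypeError as err:
    print("direct wrap failed:", err)

state = convert_final_state(raw, StrictState)  # converting first succeeds
print(type(state.data).__name__, state.data.shape)  # ndarray (2, 2)
```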
Discovered that the new JAX release has advanced sparse support enough that LindbladModel.evaluate_rhs can now be reverse-mode differentiated in sparse mode. Removed this caveat from the documentation. Also changed the test case for this function to test reverse-mode autodiff instead of forward mode. Reverse mode is more restrictive, so testing it alone is sufficient. This was done by deleting the forward-mode test case, which had been overriding an inherited reverse-mode test case.
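The following sketch assumes a recent JAX release that includes jax.experimental.sparse. It shows reverse-mode differentiation (jax.grad) through a BCOO sparse matrix-vector product, the kind of operation a sparse-mode right-hand side performs. It then compares the result against the same computation with a dense operator. The operator, the loss function, and the values are hypothetical. This is not LindbladModel.evaluate_rhs itself.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical symmetric operator, stored densely and in BCOO sparse format.
dense_op = jnp.array([[0.0, 1.0, 0.0],
                      [1.0, 0.0, 2.0],
                      [0.0, 2.0, 0.0]])
sparse_op = sparse.BCOO.fromdense(dense_op)


def make_loss(op):
    """Toy scalar loss built on a right-hand-side-like product coeff * (op @ y)."""
    def loss(coeff, y):
        return jnp.sum((coeff * (op @ y)) ** 2)
    return loss


y = jnp.array([1.0, 0.0, 1.0])
coeff = 2.0

# Reverse-mode gradients with respect to both the coefficient and the state,
# first through the sparse product, then through the dense one for comparison.
sparse_grads = jax.grad(make_loss(sparse_op), argnums=(0, 1))(coeff, y)
dense_grads = jax.grad(make_loss(dense_op), argnums=(0, 1))(coeff, y)

print(sparse_grads)  # analytically 36.0 and [24.0, 0.0, 48.0]
print(dense_grads)
```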
Because of this, bumped the minimum JAX version and explained the change in the release note.