Updating operator_collections.py to work with arraylias.
General list of changes:
All operator collections corresponding to `numpy`, `jax`, and `jax_sparse` are merged into a single class using arraylias dispatching (one each for `OperatorCollection`, `LindbladCollection`, and `VectorizedLindbladCollection`). However, I've decided to keep `scipy_sparse` as a separate special case in each class, as its interface is just too different from the others. It would be nice to handle everything in a single class, but I'm worried that doing so would require building too much scipy sparse support into our numpy alias that wouldn't be used anywhere else.
I've registered two more functions for `"jax_sparse"`: `conjugate` and `transpose`.
I've added a `linear_combo` function to the registrations for `"numpy"`, `"jax"`, and `"jax_sparse"`. For the first two this is just `tensordot`, but `"jax_sparse"` needs a somewhat hacky implementation.
I've updated `_preferred_lib` to prefer `"jax_sparse"` if present.
The model classes (`GeneratorModel`, `HamiltonianModel`, and `LindbladModel`) are currently broken. I've commented out the lines that use the old versions of the operator collection classes to avoid errors, and I'll address these in the next PR, which updates the model classes.
I've rearranged the tests for operator collections to reflect the new class structure, and tried to make sure all of the existing tests were carried over.
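For the dense libraries, `linear_combo` is described as "just `tensordot`": it computes the coefficient-weighted sum of a stack of operators, sum_i c_i A_i. The NumPy sketch below illustrates that contraction, assuming coefficients of shape `(k,)` and operators of shape `(k, n, n)`; the function here is illustrative and not the registered implementation.

```python
import numpy as np


def linear_combo(coeffs, operators):
    """Return sum_i coeffs[i] * operators[i] (illustrative sketch only)."""
    # axes=1 contracts the last axis of coeffs with the first axis of operators.
    return np.tensordot(coeffs, operators, axes=1)


# Example: combine Pauli X and Z with coefficients 2 and 3.
X = np.array([[0.0, 1.0], [1.0, 0.0]])
Z = np.array([[1.0, 0.0], [0.0, -1.0]])
coeffs = np.array([2.0, 3.0])
result = linear_combo(coeffs, np.array([X, Z]))

# Same result as the explicit sum over operators.
explicit = sum(c * op for c, op in zip(coeffs, [X, Z]))
```

Because `tensordot` contracts only the last axis of `coeffs`, the same call also handles a batch of coefficient vectors of shape `(m, k)`, returning an `(m, n, n)` array.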
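The only rule stated for `_preferred_lib` is that `"jax_sparse"` wins whenever it is present. The hypothetical pure-Python sketch below shows a priority rule of that kind. The function name `preferred_lib`, the fallback order after `"jax_sparse"` (`"jax"`, then `"numpy"`), and returning `None` when nothing matches are all assumptions for illustration, not the actual implementation.

```python
from typing import Iterable, Optional

# Hypothetical priority order: only "jax_sparse" being first is stated in the PR.
LIB_PRIORITY = ("jax_sparse", "jax", "numpy")


def preferred_lib(libs: Iterable[str]) -> Optional[str]:
    """Return the highest-priority library name found in ``libs``.

    Hypothetical sketch: returns None if none of the known libraries appear.
    """
    present = set(libs)
    for lib in LIB_PRIORITY:
        if lib in present:
            return lib
    return None
```

One plausible motivation for putting the sparse library first is that mixing a sparse operator with dense ones then keeps the collection sparse rather than densifying it.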
One more general change:
This is somewhat orthogonal to the main point of this PR, but since we are making wholesale changes, it makes sense to address it now: I'm taking this opportunity to make the operator collections AND the model classes immutable. This should simplify various things, including the `Solver` class, in which signals need to be added to, and removed from, models. This PR won't complete the change, but I'm going to mark it as closing #243, since immutable model classes would have prevented the weird behaviour reported there from occurring in the first place. For this PR, that means removing all of the setter methods for the different operator types. I will do the same for the model classes in a subsequent PR.
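Removing the setters means a collection's operators are fixed at construction, so changing them means building a new object. The sketch below illustrates that pattern with a hypothetical `FrozenCollection` class (not the actual `OperatorCollection` code) that exposes its operators only through read-only properties. Marking the stored arrays as non-writeable is an extra safeguard assumed here, not something this PR states.

```python
import numpy as np


def _frozen_copy(array):
    """Return a read-only copy of ``array``, or None if ``array`` is None."""
    if array is None:
        return None
    copied = np.array(array)  # copy so later edits to the caller's array don't leak in
    copied.setflags(write=False)  # block in-place modification of the stored copy
    return copied


class FrozenCollection:
    """Hypothetical immutable operator container (illustration only)."""

    def __init__(self, static_operator=None, operators=None):
        self._static_operator = _frozen_copy(static_operator)
        self._operators = _frozen_copy(operators)

    # Properties without setters: assigning to them raises AttributeError.
    @property
    def static_operator(self):
        return self._static_operator

    @property
    def operators(self):
        return self._operators
```

Under this pattern, "changing" the operators means constructing a new instance, e.g. `FrozenCollection(operators=new_operators)`, so any object that holds a reference to the old collection is unaffected.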
Details and comments
Current list of test files we need to make sure are passing before merge to feature branch:
test_rotating_frame.py
test_operator_collections.py
test_alias.py
We should keep track of this list in subsequent PRs.
One technical issue:
Unfortunately, for `LindbladCollection`, it isn't possible to pass JAX sparse arrays directly at construction. This was always the case, but I had hoped to change that here. Getting it to work requires setting the `n_batch` argument to a particular value when constructing the `BCOO` arrays, and this isn't something I want users to have to do. For now, they will need to pass dense arrays and select `array_library="jax_sparse"`.