Sampreet opened this issue 1 month ago
Thank you for your interest in the project. Your changes seem good for this particular issue. We don't have tests (yet) for which solvers work with jax and which do not. But it would be nice to also add an example to the documentation on how to use mcsolve when adding its support.
PS: We have a student (@rochisha0) working on supporting qutip-jax in more QuTiP functions this summer.
Great to know that the team is working along similar lines this summer. I am primarily trying to interface `qutip-jax` with `sbx` (https://github.com/araffin/sbx) for one of my current projects, which involves a multi-component continuous-variable model (hence the need for larger total Hilbert spaces). I would be happy to contribute to the development of `qutip-jax` in relevant workflows, especially JAX-based implementations for `MultiTrajSolver` and `MCSolver` to speed up intermediate steps.
> We don't have tests (yet) for which solvers work with jax and which do not. But it would be nice to also add an example to the documentation on how to use mcsolve when adding its support.
Sure, I will add an example to `doc/source/solver.rst` to update the documentation.
As regards testing, I am working on comparing the `jax` and `jaxdia` results with the expected `dia` output (similar to the tests for `MESolver` in `qutip_jax.tests.test_ode.py`), using the `seeds` argument of `MCSolver` to fix the random seed of each trajectory.
Kindly let me know whether that works for now, and whether I should coordinate with anyone working on this feature.
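The seed-based comparison relies on one property: with the same seed, two runs draw the same random numbers, so single trajectories (not just ensemble averages) can be compared up to a numerical tolerance. A minimal NumPy sketch of that idea follows, using a toy quantum-jump model instead of QuTiP: a cavity prepared in the Fock state $|n_0\rangle$ decays at rate $\kappa$, and each jump time is drawn with the waiting-time rule that Monte-Carlo solvers use. The function names, the parameter values, and the use of `float32` as a stand-in for a second backend are all hypothetical.

```python
import numpy as np


def jump_times(n0, kappa, seed, dtype=np.float64):
    """Toy quantum-jump trajectory of a cavity decaying from Fock state |n0>.

    While the cavity holds n photons, the squared norm of the unnormalised
    no-jump state decays as exp(-kappa * n * t). A jump occurs when it falls
    to a uniform random number r, i.e. after a waiting time -ln(r) / (kappa * n).
    Illustrative only; this is not QuTiP's implementation.
    """
    rng = np.random.default_rng(seed)  # same seed -> same sequence of draws
    t = dtype(0.0)
    times = []
    for n in range(n0, 0, -1):
        r = dtype(1.0 - rng.random())  # uniform on (0, 1], avoids log(0)
        t = t + (-np.log(r)) / dtype(kappa * n)
        times.append(t)  # after this jump the cavity holds n - 1 photons
    return np.array(times, dtype=dtype)


def compare_backends(n0=5, kappa=0.5, seeds=range(10)):
    """Compare per-trajectory results of a float64 'reference' run and a
    float32 run (a hypothetical stand-in for a different data layer)."""
    for seed in seeds:
        reference = jump_times(n0, kappa, seed, np.float64)
        candidate = jump_times(n0, kappa, seed, np.float32)
        np.testing.assert_allclose(candidate, reference, rtol=1e-5, atol=1e-6)
    return True


compare_backends()
```

In QuTiP, the `seeds` argument would play the role of `seed` here; the `rtol`/`atol` values are placeholders and would need tuning to the integrator tolerances actually used.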
Also, single-trajectory runs of the time-independent JC model give the following runtimes as the Hilbert-space dimension $N$ of the cavity mode increases:
$N$ | qutip | qutip-jax (CPU) | qutip-jax (GPU) |
---|---|---|---|
10000 | 1.893 | 5.431 | 6.157 |
17782 | 3.133 | 6.941 | 5.641 |
31622 | 5.855 | 9.404 | 5.341 |
56234 | 12.842 | 17.139 | 6.216 |
100000 | 23.698 | 28.956 | 6.748 |
177827 | 41.859 | 57.929 | 8.639 |
316227 | 73.380 | 101.08 | 12.207 |
PS: My VRAM is exhausted beyond $N = 5 \times 10^5$, but the trend is clear from the table.
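To read the runtime table as speed-ups, one can divide the plain QuTiP runtime by each qutip-jax runtime, so that a ratio above 1 means qutip-jax is faster. A small sketch, with the numbers copied from the table (units as reported there):

```python
# Runtimes copied from the table (same units as reported).
N_VALUES = [10000, 17782, 31622, 56234, 100000, 177827, 316227]
QUTIP = [1.893, 3.133, 5.855, 12.842, 23.698, 41.859, 73.380]
JAX_CPU = [5.431, 6.941, 9.404, 17.139, 28.956, 57.929, 101.08]
JAX_GPU = [6.157, 5.641, 5.341, 6.216, 6.748, 8.639, 12.207]


def speedups(baseline, candidate):
    """Ratio baseline / candidate for each N (> 1 means the candidate is faster)."""
    return [b / c for b, c in zip(baseline, candidate)]


def first_faster(n_values, ratios):
    """Smallest N at which the candidate beats the baseline, or None."""
    for n, ratio in zip(n_values, ratios):
        if ratio > 1:
            return n
    return None


gpu = speedups(QUTIP, JAX_GPU)
cpu = speedups(QUTIP, JAX_CPU)
for n, g, c in zip(N_VALUES, gpu, cpu):
    print(f"N = {n:>6}: GPU speed-up {g:5.2f}x, CPU speed-up {c:5.2f}x")
print("GPU first beats qutip at N =", first_faster(N_VALUES, gpu))
```

From these numbers, the GPU run first overtakes plain QuTiP at $N = 31622$ and is about 6x faster at the largest $N$, while qutip-jax on the CPU stays slower than plain QuTiP over the whole range.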
Description
While trying to implement `mcsolve` with `diffrax` (to support higher-dimensional Hilbert spaces with `jax[cuda12]`) and performing periodic measurements on the final state obtained from each run, the following issues appear one after another (each surfaces once the previous one is resolved):

1. `DiffraxIntegrator` is not added to the `MCSolver` class.
2. `qutip_jax.reshape.split_columns_jaxarray` (line 58) takes 1 positional argument, but `qutip.core.qobj.Qobj.eigenstates` passes 2 arguments (line 1530) when it is called inside `qutip.measurement.measurement_statistics_observable` (line 224).
3. `qutip.measurement.measurement_statistics_observable` line 242 throws an error saying that multidimensional indexing using lists is deprecated with `lax_numpy` (https://github.com/google/jax/issues/4564).

In my forked versions of `qutip` and `qutip-jax`, I have made a few changes to resolve these issues and support periodic measurements with Monte-Carlo quantum trajectories, as detailed below. Kindly let me know if there are any known issues with adding `DiffraxIntegrator` to `MCSolver`. Also, are there any particular tests that should be added (e.g., in `qutip.tests.solver.test_mcsolve` or `qutip_jax.tests.test_ode`) before submitting a PR? As far as I understand, `MCIntegrator` does not provide any option to pass a `seed` for its `generator`, and each trajectory evolves differently; hence only the ensemble-averaged results can be compared, using larger values of `ntraj` and larger comparison tolerances.

Snippet to Reproduce the Issues
Below is a simplified demonstration using the Jaynes-Cummings (JC) model from QuTiP's `mcsolve` tutorial, but with a time-dependent interaction term:
Changes Made
To Resolve Issue 1:
`qutip_jax.ode` is updated with:

To Resolve Issue 2:
`qutip_jax.reshape` line 58 is changed to:

The second argument is, however, redundant for `qutip_jax`.

To Resolve Issue 3:
`qutip.measurement` line 242 is changed to:

Quick Workaround for Issue 2 and Issue 3
A CSR form of the operators can be used with a dense form of the final state to perform the measurement.
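Both JAX-side failures can be reproduced without QuTiP. The sketch below uses plain JAX and is not a reconstruction of the changed lines in the fork, which are not shown here. `split_columns` is a hypothetical stand-in that tolerates the extra positional argument from Issue 2 (ignoring it, since JAX arrays are immutable), and the indexing part shows the tuple form of multidimensional integer indexing that JAX expects instead of the list form flagged in Issue 3.

```python
import jax.numpy as jnp


def split_columns(matrix, copy=None):
    """Split a 2-D array into a list of (n, 1) column arrays.

    Hypothetical stand-in for a JAX-backed split_columns. The optional `copy`
    argument is accepted only so that callers passing a second positional
    argument do not raise a TypeError; JAX arrays are immutable, so it is ignored.
    """
    return [matrix[:, k:k + 1] for k in range(matrix.shape[1])]


# Issue 2: a caller passing two positional arguments no longer fails.
columns = split_columns(jnp.eye(3), True)

# Issue 3: select x[0, 1] and x[2, 3] with integer index arrays.
x = jnp.arange(12).reshape(3, 4)
rows = jnp.array([0, 2])
cols = jnp.array([1, 3])
index = [rows, cols]      # a list of index arrays, e.g. built up in a loop
# x[index] is the non-tuple form flagged in the linked JAX issue.
picked = x[tuple(index)]  # equivalent to x[rows, cols]
print(picked)
```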
Environment Details
Mon May 27 17:40:18 2024 IST
Additional Observations
Raising a `qutip_jax.jaxdia.JaxDia` operator to a power (e.g., `a**2`) converts it to `qutip.core.data.dia.Dia`, whereas multiplying the operator by itself (e.g., `a * a`) retains the data type.