qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

Regarding Monte-Carlo quantum trajectories and measurements of observables on final states #47

Open Sampreet opened 1 month ago

Sampreet commented 1 month ago

Description

While trying to use mcsolve with diffrax (to support higher-dimensional Hilbert spaces with jax[cuda12]) and to perform periodic measurements on the final state obtained from each run, I observed the following issues (each one appears once the previous one is resolved):

  1. DiffraxIntegrator is not added to the MCSolver class.
  2. qutip_jax.reshape.split_columns_jaxarray (line 58) takes 1 positional argument, but qutip.core.qobj.Qobj.eigenstates passes 2 arguments (line 1530) when it is called inside qutip.measurement.measurement_statistics_observable (line 224).
  3. qutip.measurement.measurement_statistics_observable line 242 throws an error saying that multidimensional indexing using lists is deprecated with lax_numpy (https://github.com/google/jax/issues/4564).

In my forked versions of qutip and qutip-jax, I have made a few changes to resolve these issues and support periodic measurements with Monte-Carlo quantum trajectories, as detailed below. Kindly let me know if there are any known issues in adding DiffraxIntegrator to MCSolver. Also, are there any particular tests that need to be added (e.g., in qutip.tests.solver.test_mcsolve or qutip_jax.tests.test_ode) before submitting a PR? As far as I understand, MCIntegrator does not provide any option to pass a seed for its generator, and each trajectory evolves differently, hence only the ensemble-averaged results can be compared, using larger values of ntraj and looser comparison tolerances.
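
The comparison strategy described above can be illustrated with a toy model that does not use QuTiP at all. The following is a minimal NumPy sketch, assuming a two-level atom that starts in the excited state and decays at rate gamma: each trajectory is fully determined by its jump time, and the ensemble-averaged excited-state population should approach exp(-gamma * t). The function names, the seeds, and the tolerances are illustrative assumptions, not part of QuTiP or qutip-jax.

```python
import numpy as np

def decay_trajectory(tlist, gamma, seed):
    """Excited-state population of one quantum-jump trajectory (toy model)."""
    rng = np.random.default_rng(seed)
    # The single jump time of spontaneous emission is exponentially distributed.
    t_jump = rng.exponential(1.0 / gamma)
    # Population is 1 before the jump and 0 after it.
    return (tlist < t_jump).astype(float)

def ensemble_average(tlist, gamma, seeds):
    """Average the population over one trajectory per seed."""
    return np.mean([decay_trajectory(tlist, gamma, s) for s in seeds], axis=0)

tlist = np.linspace(0.0, 10.0, 50)
gamma = 0.5
ntraj = 4000

# Two independent ensembles built from disjoint (hypothetical) seed ranges.
avg_a = ensemble_average(tlist, gamma, range(0, ntraj))
avg_b = ensemble_average(tlist, gamma, range(ntraj, 2 * ntraj))
exact = np.exp(-gamma * tlist)

# Single trajectories differ, so only averages are compared, with a tolerance
# several times larger than the statistical error ~ 0.5 / sqrt(ntraj).
max_err_exact = np.max(np.abs(avg_a - exact))
max_err_pair = np.max(np.abs(avg_a - avg_b))
print(max_err_exact, max_err_pair)
```

With a fixed seed per trajectory the individual runs become reproducible, so two backends could in principle be compared trajectory by trajectory; without that, the tolerance has to scale with the statistical error of the ensemble.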

Snippet to Reproduce the Issues

Below is a simplified demonstration using the Jaynes-Cummings (JC) model from QuTiP's mcsolve tutorial but with a time-dependent interaction term:

import diffrax
import jax
import qutip
import qutip_jax

# JC system
N = 100
with qutip.CoreOptions(default_dtype='jaxdia'):
    a = qutip.tensor(qutip.qeye(2), qutip.destroy(N))
    sm = qutip.tensor(qutip.destroy(2), qutip.qeye(N))
    H_0 = 2.0 * jax.numpy.pi * a.dag() * a + 2.0 * jax.numpy.pi * sm.dag() * sm
    # ``H_1`` is the time-dependent part with coefficient ``H_1_fn`` and operator ``H_1_op``
    H_1_op = sm * a.dag() + sm.dag() * a
    H_1_fn = lambda t: 2.0 * jax.numpy.pi * 0.25 * jax.numpy.cos(2.0 * jax.numpy.pi * t)
    H = [H_0, [H_1_op, qutip_jax.qobjevo.JaxJitCoeff(H_1_fn)]]
# DiffraxIntegrator converts the state to JaxArray anyway, hence not using JaxDia
state = qutip.tensor(qutip.fock(2, 0, dtype='jax'), qutip.fock(N, 8, dtype='jax'))
c_ops = [jax.numpy.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]

# # uncomment to resolve *Issue 1* by adding ``DiffraxIntegrator`` to ``MCSolver``
# qutip.MCSolver.add_integrator(qutip_jax.ode.DiffraxIntegrator, 'diffrax')

# single trajectory instance of mcsolve using ``diffrax.Dopri5`` and ``diffrax.PIDController``
# with QuTiP's default ``atol`` and ``rtol`` (https://github.com/qutip/qutip-jax/issues/26)
tlist = jax.numpy.linspace(0.0, 10.0, 200)
result = qutip.mcsolve(H, state, tlist, c_ops, e_ops, ntraj=1, options=dict(
    method = 'diffrax',
    solver = diffrax.Dopri5(),
    stepsize_controller = diffrax.PIDController(atol=1e-8, rtol=1e-6),
    store_states = True,
    keep_runs_results = True
))
# measurement of the final state
position, state = qutip.measurement.measure_observable(
    state=result.states[0][-1],
    op=(a.dag() + a) / jax.numpy.sqrt(2)
)
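
For readers unfamiliar with what the final measurement step does, the following is a minimal NumPy sketch of a projective measurement of the position quadrature on a Fock state. It is not QuTiP's implementation of measure_observable; the helper names (destroy, measure_observable_sketch), the truncation n_levels = 20, and the seed are assumptions made for illustration, and the sketch assumes a non-degenerate observable.

```python
import numpy as np

def destroy(n):
    """Annihilation operator truncated to n Fock levels."""
    return np.diag(np.sqrt(np.arange(1, n)), k=1)

def measure_observable_sketch(psi, op, rng):
    """Projective measurement of a Hermitian, non-degenerate op on the ket psi."""
    eigvals, eigvecs = np.linalg.eigh(op)
    # Born rule: probability of each outcome is |<e_k|psi>|^2.
    probs = np.abs(eigvecs.conj().T @ psi) ** 2
    probs = probs / probs.sum()
    # Sample one outcome and collapse onto the matching eigenvector.
    k = rng.choice(len(eigvals), p=probs)
    return eigvals[k], eigvecs[:, k], probs, eigvals

n_levels = 20
a = destroy(n_levels)
x_op = (a.conj().T + a) / np.sqrt(2)  # position quadrature

psi = np.zeros(n_levels)
psi[2] = 1.0  # Fock state |2>

rng = np.random.default_rng(0)
position, collapsed, probs, eigvals = measure_observable_sketch(psi, x_op, rng)
print(position)
```

The eigendecomposition is the expensive part of this procedure, which is consistent with the slow eigenvalue/eigenstate calculations for large N noted under Additional Observations.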

Changes Made

To Resolve Issue 1:

qutip_jax.ode is updated with:

...
from qutip.solver.mcsolve import MCSolver
...
MCSolver.add_integrator(DiffraxIntegrator, 'diffrax')
...

To Resolve Issue 2:

qutip_jax.reshape line 58 is changed to:

def split_columns_jaxarray(matrix, copy=False):

The second argument is redundant for qutip_jax, though.
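
The idea behind this change can be sketched without QuTiP: a specialisation has to accept the same positional arguments as the generic call, even when one of them has no effect. The function below is a hypothetical NumPy stand-in, not the qutip_jax code; its name and body are assumptions for illustration only.

```python
import numpy as np

def split_columns_sketch(matrix, copy=False):
    """Split a 2D array into a list of column vectors.

    `copy` is accepted only so that callers passing two positional
    arguments do not fail; it is deliberately unused here.
    """
    matrix = np.asarray(matrix)
    return [matrix[:, k:k + 1] for k in range(matrix.shape[1])]

# A caller that passes two positional arguments now works.
columns = split_columns_sketch(np.eye(3), False)
print(len(columns), columns[0].shape)
```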

To Resolve Issue 3:

qutip.measurement line 242 is changed to:

            values.append(np.mean(eigenvalues[np.array(present_group)]))
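
The effect of this one-line change can be shown in isolation. The following is a minimal NumPy sketch of grouping (near-)degenerate eigenvalues and averaging each group, assuming the eigenvalues are sorted and non-empty; it is not the QuTiP source, and the function name and tolerance are hypothetical. Indexing with an index array instead of a Python list is accepted by both NumPy and JAX arrays, whereas JAX rejects a plain list as an index.

```python
import numpy as np

def group_mean_eigenvalues(eigenvalues, tol=1e-12):
    """Average sorted eigenvalues within groups that agree up to tol."""
    eigenvalues = np.asarray(eigenvalues)
    values = []
    present_group = [0]
    for i in range(1, len(eigenvalues)):
        if abs(eigenvalues[i] - eigenvalues[present_group[-1]]) <= tol:
            present_group.append(i)
        else:
            # Index with an array, not a list, so JAX arrays also work.
            values.append(np.mean(eigenvalues[np.array(present_group)]))
            present_group = [i]
    values.append(np.mean(eigenvalues[np.array(present_group)]))
    return values

grouped = group_mean_eigenvalues([1.0, 1.0, 2.0, 3.0, 3.0])
print(grouped)
```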

Quick Workaround for Issue 2 and Issue 3

A CSR form of the operators can be used with a dense form of the final state to perform the measurement.

position, state = qutip.measurement.measure_observable(
    state=state.to('dense'),
    op=((a.dag() + a) / jax.numpy.sqrt(2)).to('csr')
)

Environment Details

Software          Version
QuTiP             5.0.1
Numpy             1.26.4
SciPy             1.13.0
matplotlib        3.8.4
Number of CPUs    12
BLAS Info         Generic
IPython           8.24.0
Python            3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
OS                posix [linux]
Cython            3.0.10

Mon May 27 17:40:18 2024 IST

Additional Observations

  1. The speedup on the GPU is substantial for Hilbert space dimensions (N) greater than 10000. However, the eigenvalue/eigenstate calculations for N > 1000 are very slow (on both the CPU and the GPU).
  2. Squaring a qutip_jax.jaxdia.JaxDia operator (e.g., a**2) converts it to qutip.core.data.dia.Dia, whereas multiplying the operator by itself (e.g., a * a) retains the datatype.

Ericgig commented 1 month ago

Thank you for your interest in the project. Your changes seem good for this particular issue. We don't have tests (yet) for which solvers work with jax and which not. But it would be nice to also add an example in the documentation on how to use mcsolve when adding its support.

PS. We have a student (@rochisha0) working this summer on supporting qutip-jax in more of QuTiP's functions.

Sampreet commented 1 month ago

Great to know that the team is working on similar lines this summer. I am primarily trying to interface qutip-jax with sbx (https://github.com/araffin/sbx) for one of my current projects involving a multi-component continuous variable model (hence the requirement of larger total Hilbert spaces) and would be happy to contribute towards the development of qutip-jax in relevant workflows (especially JAX-based implementations for MultiTrajSolver and MCSolver to speed up intermediate steps).

> We don't have tests (yet) for which solvers work with jax and which not. But it would be nice to also add an example in the documentation on how to use mcsolve when adding its support.

Sure, I shall add an example in doc/source/solver.rst to update the documentation. As regards testing, I am working on comparing the jax and jaxdia results with the expected dia output (similar to the tests for MESolver in qutip_jax.tests.test_ode.py) using the seeds argument of MCSolver for each trajectory. Kindly let me know if that works (for now) and also if I need to connect with anyone working on this feature.

Also, single-trajectory runs of the time-independent JC model give me the following runtimes with increasing Hilbert space dimension ($N$) of the cavity mode:

$N$       qutip     qutip-jax (CPU)   qutip-jax (GPU)
10000     1.893     5.431             6.157
17782     3.133     6.941             5.641
31622     5.855     9.404             5.341
56234     12.842    17.139            6.216
100000    23.698    28.956            6.748
177827    41.859    57.929            8.639
316227    73.380    101.08            12.207

PS. My VRAM gets exhausted beyond $N = 5 \times 10^5$, but one can guess the trend from the table.
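
The trend in the table can be made explicit by computing the ratio of the qutip runtime to the qutip-jax (GPU) runtime for each $N$. The short script below only re-uses the numbers reported above (their unit is not stated in the thread); the variable names are illustrative.

```python
# Runtimes copied from the table above (unit as reported, not stated).
dims = [10000, 17782, 31622, 56234, 100000, 177827, 316227]
t_qutip = [1.893, 3.133, 5.855, 12.842, 23.698, 41.859, 73.380]
t_jax_gpu = [6.157, 5.641, 5.341, 6.216, 6.748, 8.639, 12.207]

# Speedup > 1 means the GPU run was faster than plain qutip.
speedup = [q / g for q, g in zip(t_qutip, t_jax_gpu)]

# Smallest dimension at which the GPU run overtakes plain qutip.
crossover = next(n for n, s in zip(dims, speedup) if s > 1.0)

for n, s in zip(dims, speedup):
    print(f"N = {n:>6d}: qutip / qutip-jax (GPU) = {s:.2f}")
print("GPU faster from N =", crossover)
```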