Qiskit-Extensions / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-extensions.github.io/qiskit-dynamics/
Apache License 2.0

Case in which DynamicsBackend.run hangs #185

Closed · DanPuzzuoli closed this issue 1 year ago

DanPuzzuoli commented 1 year ago

Information

What is the current behavior?

This is a bizarre bug that I haven't been able to get to the bottom of. The full code is below, but the setup is: build a Solver for a two-transmon model with a single drive channel, then simulate a simple schedule consisting of a Gaussian pulse followed by an acquire instruction.

With the above setup, simulating the schedule directly with Solver.solve works as expected, but running the same schedule through DynamicsBackend.run (with seemingly the same options) never terminates.

In the cases in which this happens, the weird thing is that DynamicsBackend.run seems to hang at the step where the ODE solver is called (when debugging and stepping over things it gets stuck there), yet all DynamicsBackend.run does up to that point is create a DynamicsJob instance and then call Solver.solve. I.e., for the purposes of this bug, DynamicsBackend.run is basically a light wrapper around Solver.solve.

Steps to reproduce the problem

Setting things up:

# Configure to use JAX internally
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update('jax_platform_name', 'cpu')
from qiskit_dynamics.array import Array
Array.set_default_backend('jax')

import numpy as np

dim = 2

v0 = 5.1e9
anharm0 = -0.33e9
r0 = 0.1e9

v1 = 5.15e9
anharm1 = -0.33e9
r1 = 0.1e9

J = 0.002e9

a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = np.diag(np.sqrt(np.arange(1, dim)), -1)
N = np.diag(np.arange(dim))

ident = np.eye(dim, dtype=complex)
full_ident = np.eye(dim**2, dtype=complex)

N0 = np.kron(ident, N)
N1 = np.kron(N, ident)

a0 = np.kron(ident, a)
a1 = np.kron(a, ident)

a0dag = np.kron(ident, adag)
a1dag = np.kron(adag, ident)

static_ham0 = 2 * np.pi * v0 * N0 + np.pi * anharm0 * N0 * (N0 - full_ident)
static_ham1 = 2 * np.pi * v1 * N1 + np.pi * anharm1 * N1 * (N1 - full_ident)

static_ham_full = static_ham0 + static_ham1 + 2 * np.pi * J * ((a0 + a0dag) @ (a1 + a1dag))

drive_op0 = 2 * np.pi * r0 * (a0 + a0dag)

from qiskit_dynamics import Solver

# build solver
dt = 1/4.5e9

solver = Solver(
    static_hamiltonian=static_ham_full,
    hamiltonian_operators=[drive_op0],
    rotating_frame=np.diag(static_ham_full),
    hamiltonian_channels=['d0'],
    channel_carrier_freqs={'d0': v0},
    dt=dt,
    #evaluation_mode="sparse"
)

from qiskit import pulse

schedules = []

gauss = pulse.library.Gaussian(
    128, 1., 256, name="Parametric Gauss"
)

with pulse.build() as schedule:
    with pulse.align_right():
        pulse.play(gauss, pulse.DriveChannel(0))
        pulse.acquire(duration=1, qubit_or_channel=0, register=pulse.MemorySlot(0))

Simulating the schedule using Solver.solve works:

from qiskit.quantum_info import Statevector

y0 = np.zeros(dim**2, dtype=complex)
y0[0] = 1.

solver.solve(
    t_span=np.array([0., schedule.duration*dt]), 
    y0=Statevector(y0), 
    signals=schedule,
    method='jax_odeint',
    atol=1e-6, rtol=1e-8
)
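
For reference, assigning the result of this call (sol below is just an illustrative name) lets the returned states be inspected directly; sol.y holds the solution at the requested times:

sol = solver.solve(
    t_span=np.array([0., schedule.duration*dt]),
    y0=Statevector(y0),
    signals=schedule,
    method='jax_odeint',
    atol=1e-6, rtol=1e-8
)

# the final entry of sol.y is the Statevector at t = schedule.duration * dt
print(sol.y[-1].probabilities())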

Building a DynamicsBackend with this Solver instance and calling run (with seemingly the same options) leads to a never-ending computation:

from qiskit_dynamics import DynamicsBackend

solver_options = {'method': 'jax_odeint', 'atol': 1e-6, 'rtol': 1e-8}

backend = DynamicsBackend(
    solver=solver,
    subsystem_dims=[dim, dim],
    solver_options=solver_options,
)

backend.run(schedule).result()

Variations of the above code that do and don't work

This was discovered when I was trying to see how a sparse simulation performed, hence the sparse configuration of the Solver.
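
For reference, the sparse configuration amounts to enabling the evaluation_mode option commented out in the Solver construction above; a minimal sketch (sparse_solver and sparse_backend are just illustrative names):

# identical construction to the Solver above, but with sparse evaluation enabled
sparse_solver = Solver(
    static_hamiltonian=static_ham_full,
    hamiltonian_operators=[drive_op0],
    rotating_frame=np.diag(static_ham_full),
    hamiltonian_channels=['d0'],
    channel_carrier_freqs={'d0': v0},
    dt=dt,
    evaluation_mode="sparse",
)

sparse_backend = DynamicsBackend(
    solver=sparse_solver,
    subsystem_dims=[dim, dim],
    solver_options=solver_options,
)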

Additional tested situations:

What's confusing is that the above would normally suggest some issue with the differential equation solving... however, Solver.solve works in all of the above configurations. Only DynamicsBackend.run hangs, and it specifically hangs when the ODE solver is called!

I've also disabled jitting by changing the Solver.solve code to always skip the auto-jitting behaviour, and the issue persists, so it doesn't seem to be anything weird going on with the internal auto-jitting code.

Further steps

I'll keep trying to debug with jitting disabled to see where the difference might lie.

DanPuzzuoli commented 1 year ago

This issue actually even persists if I drop the JAX configuration and use a scipy solver, which is "good" in that it means it won't be necessary to dig further into JAX tracing or anything like that.
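
A minimal sketch of that variant, assuming the setup above ('RK45' is a standard scipy solve_ivp method string accepted by Solver.solve; scipy_backend is just an illustrative name):

# same model and schedule, but simulated with a scipy variable-step solver
scipy_backend = DynamicsBackend(
    solver=solver,
    subsystem_dims=[dim, dim],
    solver_options={'method': 'RK45', 'atol': 1e-6, 'rtol': 1e-8},
)

# hangs in the same way as the jax_odeint version above
scipy_backend.run(schedule).result()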

DanPuzzuoli commented 1 year ago

Okay, the bug has been figured out. It is somehow a huge mistake in the code that miraculously doesn't impact DynamicsBackend.run results, but in some cases accidentally makes the computation excessively expensive.

Amazingly, the t_span values being fed into Solver.solve within DynamicsBackend.run are the sample numbers, not the sample numbers * dt. So, in the above example, the acquire instruction occurs at sample number 127, meaning the integration time should be 127 * dt == 2.822222222222222e-08 seconds, but DynamicsBackend is instead passing the integration interval as [0, 127].
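
Concretely, with dt = 1/4.5e9 as in the reproduction above:

dt = 1/4.5e9

correct_t_span = [0., 127 * dt]   # [0.0, 2.822222222222222e-08]
incorrect_t_span = [0., 127.]     # what DynamicsBackend.run was actually passing

# the incorrect endpoint is 1/dt = 4.5e9 times too large (roughly 10 orders of magnitude)
print(incorrect_t_span[1] / correct_t_span[1])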

This is a huge mistake, but amazingly it doesn't have a mathematical impact on the results of DynamicsBackend: once the last pulse in the schedule has been played, the remaining evolution is generated by the static Hamiltonian alone, and since the backend computes measurement probabilities in the dressed basis (the eigenbasis of the static Hamiltonian), that extra evolution only changes phases and leaves the sampled populations unchanged.

Further, in cases where the rotating frame was set to be the full static Hamiltonian, after 127 * dt the RHS of the DE is exactly 0., so even though the integration was proceeding way beyond the desired limit (by about 10 orders of magnitude), the variable-step solvers would quickly start taking extremely large steps until they hit the endpoint. This resulted in no obvious difference in the simulation run time for the correct vs. incorrect integration intervals.
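
A quick way to see this, assuming the setup from the reproduction above (full_frame_solver is just an illustrative name):

# same model, but with the rotating frame set to the full static Hamiltonian
full_frame_solver = Solver(
    static_hamiltonian=static_ham_full,
    hamiltonian_operators=[drive_op0],
    rotating_frame=static_ham_full,
    hamiltonian_channels=['d0'],
    channel_carrier_freqs={'d0': v0},
    dt=dt,
)

# in this frame the RHS is exactly 0 once the drive has ended, so the
# variable-step solver takes huge steps and the (wrongly long) interval
# [0, 127] still finishes quickly
full_frame_solver.solve(
    t_span=np.array([0., 127.]),
    y0=Statevector(y0),
    signals=schedule,
    method='jax_odeint',
    atol=1e-6, rtol=1e-8
)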

The hanging was observed initially when first setting up a DynamicsBackend with a Solver in sparse mode, as in this case the rotating frame is not set to the full static Hamiltonian, but to its diagonal part only. As a result, the RHS is no longer 0. for all times after 127 * dt, and the solver is therefore forced to keep taking small steps (hence the apparent hang).

This will be easily fixed by correctly multiplying the integration times by dt.