qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0

jax_odeint seems to not do any evolution when the amp at the first "dt" is 0. #238

Closed. ArbitRandomUser closed this issue 1 year ago.

ArbitRandomUser commented 1 year ago

Information

What is the current behavior?

If the pulse amplitude in the first "dt" is 0, the resulting counts are always {'000': no_shots}, i.e. all of the probability amplitude appears to stay on the 000 basis state regardless of the evolution.

Steps to reproduce the problem

Run the attached file mwe.py. The example plays a constant pulse with amplitude 0 for one dt and then a constant pulse with amplitude 1 for 100 dt.

Toggle the comment on the zero-amplitude line marked "TOGGLE THIS COMMENT" (line 80 in the original mwe.py) to see the difference. The duration 100 can be changed to other values: whatever the duration of the subsequent pulse, the result is always {'000': no_shots}.

What is the expected behavior?

Some probability should develop on the other basis states. However, whatever the schedule, if the amplitude during the first dt is 0 the result is always {'000': no_shots}. This is not restricted to playing pulse.Constant; playing a pulse.Waveform whose amplitude starts at 0 gives the same bug.

This bug only occurs when passing 'jax_odeint' as the solver method; using NumPy as the backend does not give the same problem.

Suggested solutions

Something to do with jax_odeint? I'm not really sure.

mwe.py

import matplotlib.pyplot as plt
from qiskit import pulse
import numpy as np
import jax
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
from qiskit import QuantumCircuit
from qiskit.circuit import Gate
from qiskit import transpile
from qiskit_dynamics import Solver
from qiskit_dynamics import DynamicsBackend
from qiskit_dynamics.array import Array
Array.set_default_backend('jax')

dim = 3
v0 = 4.86e9
anharm0 = -0.32e9
r0 = 0.22e9
v1 = 4.97e9
anharm1 = -0.32e9
r1 = 0.26e9
v2 = 4.51e9
anharm2 = -0.32e9
r2 = 0.25e9
J = 0.011e9

a = np.diag(np.sqrt(np.arange(1, dim)), 1)
adag = np.diag(np.sqrt(np.arange(1, dim)), -1)
N = np.diag(np.arange(dim))
ident = np.eye(dim, dtype=complex)
full_ident = np.eye(dim**3, dtype=complex)
N0 = np.kron(ident,np.kron(ident, N))
N1 = np.kron(ident,np.kron(N, ident))
N2 = np.kron(N, np.kron(ident,ident))
a0 = np.kron(ident,np.kron(ident, a))
a1 = np.kron(ident,np.kron(a,ident))
a2 = np.kron(a, np.kron(ident,ident))
a0dag = np.kron(ident, np.kron(ident, adag))
a1dag = np.kron(ident, np.kron(adag, ident))
a2dag = np.kron(adag,  np.kron(ident, ident))

static_ham0 = 2 * np.pi * v0 * N0 + np.pi * anharm0 * N0 * (N0 - full_ident)
static_ham1 = 2 * np.pi * v1 * N1 + np.pi * anharm1 * N1 * (N1 - full_ident)
static_ham2 = 2 * np.pi * v2 * N2 + np.pi * anharm2 * N2 * (N2 - full_ident)

static_ham_full = static_ham0 + static_ham1 + static_ham2 + 2 * np.pi * J * ((a0 + a0dag) @ (a1 + a1dag) @ (a2+a2dag))

drive_op0 = 2 * np.pi * r0 * (a0 + a0dag)
drive_op1 = 2 * np.pi * r1 * (a1 + a1dag)
drive_op2 = 2 * np.pi * r2 * (a2 + a2dag)

# build solver
dt = 1/4.5e9

solver = Solver(
    static_hamiltonian=static_ham_full,
    hamiltonian_operators=[drive_op0, drive_op1, drive_op2, drive_op0, drive_op1, drive_op2],
    rotating_frame=static_ham_full,
    hamiltonian_channels=["d0", "d1","d2" ,"u0", "u1", "u2"],
    channel_carrier_freqs={"d0": v0, "d1": v1, "d2":v2 , "u0": v1, "u1": v0, "u2": v2},
    dt=dt,
)

# Consistent solver options to use throughout
solver_options = {"method": "jax_odeint", "atol": 1e-6, "rtol": 1e-8}
backend = DynamicsBackend(
    solver=solver,
    subsystem_dims=[dim, dim, dim], # for computing measurement data
    solver_options=solver_options,  # to be used every time run is called
)

##
circ = QuantumCircuit(3,3)
with pulse.build() as sched:
    with pulse.align_sequential():
        pulse.play(pulse.Constant(1,0.0),pulse.DriveChannel(0))##TOGGLE THIS COMMENT
        pulse.play(pulse.Constant(100,1.0),pulse.DriveChannel(0))
cusgate = Gate('cusgate',3,[])
circ.add_calibration(cusgate,qubits=(0,1,2),schedule=sched)
circ.append(cusgate,[0,1,2])
circ.measure(0,0)
circ.measure(1,1)
circ.measure(2,2)
circ = transpile(circ,backend)
job = backend.run(circ, shots=1000)
result = job.result()
print(result.get_counts())
DanPuzzuoli commented 1 year ago

Add 'hmax': dt to solver_options, as is done in the DynamicsBackend tutorial. This limits the step size so that it never exceeds the width of one sample (dt), and should solve this problem.
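Applied to mwe.py, the suggested fix is a single extra entry in the options dictionary. A minimal sketch, reusing the dt and solver_options names from mwe.py; the DynamicsBackend call is left as a comment so the snippet runs on its own:

```python
# Sample width used in mwe.py (seconds).
dt = 1 / 4.5e9

# Same options as in mwe.py, plus "hmax": dt, which caps every integration
# step at one sample width so the solver cannot step over part of the schedule.
solver_options = {"method": "jax_odeint", "atol": 1e-6, "rtol": 1e-8, "hmax": dt}

# In mwe.py the dictionary is then passed to the backend exactly as before:
# backend = DynamicsBackend(
#     solver=solver,
#     subsystem_dims=[dim, dim, dim],
#     solver_options=solver_options,
# )
```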

'jax_odeint' seems to be pretty aggressive in the rate at which it increases the step size. As the simulation is by default performed in the rotating frame of the drift (static) Hamiltonian, if the pulse amplitudes are 0 at some point, then the Hamiltonian in this frame is exactly 0, and hence there is no evolution. During this period the solver rapidly increases the step size, to the point that it can jump straight to the end of the simulation without ever seeing the later pulse. Limiting the max step size as above fixes this issue.
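The step-size behaviour described above can be reproduced without qiskit-dynamics, directly with JAX's experimental odeint, which we assume is what the 'jax_odeint' method wraps (with atol, rtol and hmax passed through). The sketch below uses a hypothetical two-component toy model instead of the three-transmon model of mwe.py. Like the rotating-frame Hamiltonian, its right-hand side is exactly zero while the drive is off. The drive is off for the first sample and on for the next 100 samples, and its amplitude w is an illustrative choice that rotates the state by π/2:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.ode import odeint

# 64-bit floats, as in mwe.py; set before any arrays are created.
jax.config.update("jax_enable_x64", True)

dt = 1 / 4.5e9          # sample width from mwe.py (seconds)
t_on = dt               # drive is off for the first sample ...
t_off = 101 * dt        # ... then on for the next 100 samples
w = np.pi / (200 * dt)  # hypothetical amplitude: total rotation angle pi/2

# Generator of a planar rotation; stands in for the rotating-frame
# Hamiltonian, which vanishes whenever the drive is zero.
A = jnp.array([[0.0, -1.0], [1.0, 0.0]])


def rhs(y, t):
    # Piecewise-constant drive: 0 on [0, t_on), w on [t_on, t_off), 0 afterwards.
    g = jnp.where((t >= t_on) & (t < t_off), w, 0.0)
    return g * (A @ y)


y0 = jnp.array([1.0, 0.0])
ts = jnp.array([0.0, t_off])

# Default (unbounded) step size: the right-hand side is zero at t = 0 and the
# error estimate is zero while the drive is off, so the solver can take a step
# longer than the whole schedule and never sample the pulse.
y_default = odeint(rhs, y0, ts, rtol=1e-8, atol=1e-8)[-1]

# With hmax = dt every step is at most one sample long, so the pulse is seen.
y_hmax = odeint(rhs, y0, ts, rtol=1e-8, atol=1e-8, hmax=dt)[-1]

print("without hmax:", y_default)
print("with hmax=dt:", y_hmax)
```

If the toy model behaves as we expect, y_default stays at the initial state [1, 0], mirroring the constant {'000': no_shots} result, while y_hmax ends close to [0, 1]. The model, amplitude and tolerances are illustrative assumptions, not values taken from qiskit-dynamics.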

I'm going to close this for now, but if this doesn't solve your problem we can reopen.

ArbitRandomUser commented 1 year ago

Solved.