Qiskit / qiskit

Qiskit is an open-source SDK for working with quantum computers at the level of extended quantum circuits, operators, and primitives.
https://www.ibm.com/quantum/qiskit
Apache License 2.0

Write a Pulse Schedule Transformation that reduces all pulse channels to a single Waveform #6956

Open taalexander opened 3 years ago

taalexander commented 3 years ago

What is the expected enhancement?

As noted in this Slack conversation, there is a desire to be able to reduce all PulseChannels in a Schedule/ScheduleBlock into a single waveform. This means compressing all Play/ShiftPhase/SetPhase/ShiftFrequency/SetFrequency instructions into a single Waveform and also recursively unrolling internal Call instructions.
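To make the frame semantics concrete, here is a minimal sketch of how the ShiftPhase/SetPhase/ShiftFrequency/SetFrequency instructions on one channel could be tracked to find the (phase, frequency) in effect when each Play starts. `FrameInstruction` and `frame_at_plays` are hypothetical illustrations, not Qiskit APIs, and the rule that a frame update at time t applies to a Play that also starts at t is an assumption.

```python
from dataclasses import dataclass
from typing import List, Tuple


# Hypothetical, simplified stand-in for the Qiskit frame instructions.
# This is NOT a qiskit.pulse class, only a (t0, kind, value) record for illustration.
@dataclass(frozen=True)
class FrameInstruction:
    t0: int       # start time in samples (units of dt)
    kind: str     # "ShiftPhase", "SetPhase", "ShiftFrequency" or "SetFrequency"
    value: float  # radians for phase instructions, Hz for frequency instructions


def frame_at_plays(
    frame_instructions: List[FrameInstruction],
    play_times: List[int],
    init_phase: float = 0.0,
    init_freq: float = 0.0,
) -> List[Tuple[float, float]]:
    """Return the (phase, frequency) in effect at each Play start time.

    Results are ordered by sorted play time. Assumption: a frame instruction
    at time t applies to a Play that also starts at t.
    """
    # sorted() is stable, so instructions at the same t0 keep their order.
    events = sorted(frame_instructions, key=lambda inst: inst.t0)
    phase, freq = init_phase, init_freq
    frames = []
    i = 0
    for t_play in sorted(play_times):
        # Apply every frame update scheduled at or before this Play.
        while i < len(events) and events[i].t0 <= t_play:
            inst = events[i]
            if inst.kind == "ShiftPhase":
                phase += inst.value
            elif inst.kind == "SetPhase":
                phase = inst.value
            elif inst.kind == "ShiftFrequency":
                freq += inst.value
            elif inst.kind == "SetFrequency":
                freq = inst.value
            else:
                raise ValueError(f"unknown frame instruction {inst.kind}")
            i += 1
        frames.append((phase, freq))
    return frames
```

For the two-Play example on DriveChannel(0) further down this thread (a ShiftPhase(1.2) between two 2688-sample Gaussians), `frame_at_plays([FrameInstruction(2688, "ShiftPhase", 1.2)], [0, 2688])` gives `[(0.0, 0.0), (1.2, 0.0)]`.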

It should be noted that there are likely relevant routines in the plotter that perform similar transformations.

This could be added as a new transform module here.

snsunx commented 3 years ago

I traced back from Schedule.draw and found that the function gen_filled_waveform_stepwise in this script generates the waveform for a channel.

I tried the following code, which plots the real and imaginary parts of all the pulses on a given channel:

import matplotlib.pyplot as plt

from qiskit.pulse import Schedule, Gaussian, Play, ShiftPhase, DriveChannel
from qiskit.visualization.pulse_v2 import device_info, stylesheet
from qiskit.visualization.pulse_v2.events import ChannelEvents
from qiskit.visualization.pulse_v2.generators import gen_filled_waveform_stepwise

# Construct the Gaussian pulse and drive channels
gaussian_pulse = Gaussian(duration=2688, amp=1, sigma=336)
drive_chan0 = DriveChannel(0)
drive_chan1 = DriveChannel(1)

# Construct a schedule with different phase shifts on the two channels
sched = Schedule()
sched += Play(gaussian_pulse, drive_chan0)
sched += ShiftPhase(1.2, drive_chan0)
sched += Play(gaussian_pulse, drive_chan0)
sched += Play(gaussian_pulse, drive_chan1)
sched += ShiftPhase(1.8, drive_chan1)
sched += Play(gaussian_pulse, drive_chan1)
# sched.draw();

# Generate waveforms on channel 0
chan_events = ChannelEvents.load_program(sched, drive_chan0)
waveforms = chan_events.get_waveforms()
line_data_lists = [gen_filled_waveform_stepwise(
                   waveform,
                   formatter=stylesheet.QiskitPulseStyle().formatter,
                   device=device_info.OpenPulseBackendInfo())
                   for waveform in waveforms]

# Plot the waveforms
plt.figure()
for line_data_list in line_data_lists:
    for line_data in line_data_list:
        if line_data.data_type == 'Waveform.Real':
            real, = plt.plot(line_data.xvals, line_data.yvals, color='C0', label='Real')
        elif line_data.data_type == 'Waveform.Imag':
            imag, = plt.plot(line_data.xvals, line_data.yvals, color='C1', label='Imag')
plt.xlabel("System Cycle Time (dt)")
plt.ylabel("Pulse Amplitude")
plt.legend(handles=[real, imag])
plt.show()

The waveform sample values are stored in line_data.yvals. One minor problem is that the generated waveforms are stepwise, as implemented here; for the current feature this can be avoided simply by leaving out those two lines.

One thing I'm not sure about is how ShiftFrequency and SetFrequency affect the waveforms. As far as I understand, they won't affect the waveform sample values.

nkanazawa1989 commented 3 years ago

This is what the plotter calls: https://github.com/Qiskit/qiskit-terra/blob/75e06dc915f764ea4e5c1a57097e980e9d01b119/qiskit/pulse/transforms/base_transforms.py#L25-L28 (this pass flattens Call instructions and performs the other high-level transformations).

The target_qobj_transform is called here https://github.com/Qiskit/qiskit-terra/blob/75e06dc915f764ea4e5c1a57097e980e9d01b119/qiskit/visualization/pulse_v2/core.py#L254

Actually, ShiftFrequency and SetFrequency do affect the waveform. To be played, a waveform needs to be multiplied by the phase factor $\exp(i 2\pi f t)$ associated with its frame. The frame rotates at the frequency $f$, which is defined by backend.defaults().qubit_freq_est unless it is updated by SetFrequency or ShiftFrequency. This phase factor is usually computed by the backend control electronics (in either software or hardware), but the pulse plotter does not emulate this behavior. I guess @taalexander also needs this option on the front end.
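A minimal sketch of that frame factor, to show why a frequency change alters the sample values even though the envelope magnitude is unchanged. `apply_frame` is a hypothetical helper, and the conventions (absolute time $t = (t_0 + k)\,dt$ for sample $k$, frequency in Hz, $dt$ in seconds, phase in radians) are assumptions for illustration, not a description of what any backend does.

```python
import numpy as np


def apply_frame(samples, t0, dt, freq, phase):
    """Multiply complex envelope samples by exp(i * (2*pi*freq*t + phase)).

    Illustrative assumptions: sample k is played at absolute time
    t = (t0 + k) * dt, freq is in Hz, dt is in seconds, phase is in radians.
    """
    samples = np.asarray(samples, dtype=complex)
    t = (t0 + np.arange(len(samples))) * dt
    return samples * np.exp(1j * (2 * np.pi * freq * t + phase))


# A constant envelope picks up a rotating phase once the frame frequency is nonzero.
rotated = apply_frame(np.ones(4), t0=0, dt=0.25, freq=1.0, phase=0.0)
print(np.round(rotated, 12))        # approximately [1, 1j, -1, -1j]
print(np.abs(rotated))              # magnitudes stay at 1
```

With `freq=0` and `phase=0` the samples come back unchanged, which matches the intuition that only the frame, not the envelope, is modified.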

Usually we don't need to compute the frame on the front end; however, if we need a custom frame (e.g., for a qudit), we do have to compute it. This is the relevant WIP PR: #5977. Apart from this, such frame computation might be necessary for the pulse simulator (although it usually uses a rotating-frame Hamiltonian to avoid a fine dt step, so the waveform in the lab frame is not needed), or for the pulse plotter, to show what the waveform looks like on the actual hardware.
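To illustrate the rotating-frame versus lab-frame point, here is a hedged sketch of upconverting a complex baseband envelope to a real lab-frame signal. It assumes one common convention, $s(t) = \mathrm{Re}[d(t)\,e^{i 2\pi f_c t}]$, with $d(t)$ held constant over each `dt`; `to_lab_frame` and its `oversample` parameter are hypothetical. The need for a grid finer than `dt` to resolve the carrier is exactly the cost a rotating-frame simulation avoids.

```python
import numpy as np


def to_lab_frame(envelope, dt, carrier_freq, oversample=16):
    """Upconvert a complex baseband envelope to a real lab-frame signal.

    Assumed convention: s(t) = Re[d(t) * exp(i * 2*pi * carrier_freq * t)],
    with d(t) piecewise constant over each dt interval. The result is
    evaluated on a grid `oversample` times finer than dt so the carrier
    oscillation is resolved.
    """
    envelope = np.asarray(envelope, dtype=complex)
    d_fine = np.repeat(envelope, oversample)        # piecewise-constant envelope
    t = np.arange(len(d_fine)) * (dt / oversample)  # fine time grid in seconds
    return np.real(d_fine * np.exp(1j * 2 * np.pi * carrier_freq * t))


# One sample of unit amplitude, a 1 Hz carrier, dt = 1 s, 4 points per dt:
print(np.round(to_lab_frame([1.0], dt=1.0, carrier_freq=1.0, oversample=4), 12))
```

In a rotating frame at the carrier frequency (here `carrier_freq=0.0`), the envelope itself is the signal and no oversampling is needed.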

Here, you need to add multiple transform passes that

Then you can create a new transform that generates low-level waveform data.

nkanazawa1989 commented 3 years ago

If you add such a transform pass, you can deprecate the pulse program loader tied to the plotter. Then the plotter will become much simpler. https://github.com/Qiskit/qiskit-terra/blob/main/qiskit/visualization/pulse_v2/events.py

snsunx commented 3 years ago

Hi @nkanazawa1989, thank you for your detailed explanation. Now I can see that target_qobj_transform flattens a Schedule so that it can be used for later processing.

Should the final function look roughly like this? Here sched and chan are the input variables.

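import numpy as np
from qiskit.pulse import Play, Waveform
from qiskit.pulse.transforms import target_qobj_transform
from qiskit.visualization.pulse_v2.events import ChannelEvents

# Inputs: sched, chan, apply_carrier_wave, qubit_freq_est; some_factor is still a placeholder
sched_transformed = target_qobj_transform(sched)
chan_events = ChannelEvents.load_program(sched_transformed, chan)
waveform_inst_tups = chan_events.get_waveforms()  # An iterator of PulseInstruction tuples

chan_waveform = np.zeros((sched_transformed.duration, ), dtype=complex)
for inst_tup in waveform_inst_tups:
    if isinstance(inst_tup.inst, Play):
        t0 = inst_tup.t0
        tf = t0 + inst_tup.inst.duration
        t_arr = np.arange(t0, tf)
        phase = inst_tup.frame.phase
        freq = inst_tup.frame.freq

        # Parametric pulses need get_waveform(); a Waveform already holds its samples.
        # Copy the samples so the pulse's own array is not modified in place.
        pulse = inst_tup.inst.pulse
        samples = pulse.samples if isinstance(pulse, Waveform) else pulse.get_waveform().samples
        pulse_waveform = np.array(samples, dtype=complex)
        pulse_waveform *= np.exp(1j * freq * t_arr * some_factor)  # some_factor will be worked out later
        pulse_waveform *= np.exp(1j * phase)
        if apply_carrier_wave:
            pulse_waveform *= np.exp(1j * qubit_freq_est * t_arr * some_factor)
        chan_waveform[t0:tf] = pulse_waveform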

I find that get_waveforms already calculates the frequency and phase values of each pulse (the frame attribute), so they can be used directly when applying the frequency and phase factors. The some_factor variable should depend on the unit conversions for ns and GHz, but I'll work out the exact expression later. If the above implementation looks about right, I'll go on to write it into a script.
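As a hedged sketch of how the unit conversion could work out: if `t_arr` counts samples and the frame frequency is in Hz (the unit Qiskit's SetFrequency/ShiftFrequency use), then the elapsed time is `t_arr * dt` seconds and the placeholder becomes `some_factor = 2 * np.pi * dt`. The self-contained `flatten_channel` function below is hypothetical; its `(t0, samples, phase, freq)` tuples stand in for what could be pulled out of the PulseInstruction tuples above, and the carrier-wave option is left out.

```python
import numpy as np


def flatten_channel(plays, duration, dt):
    """Combine the Play instructions of one channel into a single complex array.

    plays: iterable of (t0, samples, phase, freq) tuples, with t0 in samples,
    phase in radians and freq in Hz (hypothetical input format).
    dt: sample period in seconds. Gaps between plays are left at zero.
    """
    out = np.zeros(duration, dtype=complex)
    some_factor = 2 * np.pi * dt  # converts (freq in Hz) * (time in samples) to radians
    for t0, samples, phase, freq in plays:
        samples = np.asarray(samples, dtype=complex)
        tf = t0 + len(samples)
        if tf > duration:
            raise ValueError("Play extends past the schedule duration")
        t_arr = np.arange(t0, tf)
        out[t0:tf] = samples * np.exp(1j * (freq * t_arr * some_factor + phase))
    return out


# Two plays with a pi/2 phase shift on the second, and an idle gap in between.
print(np.round(flatten_channel([(0, [1, 1], 0.0, 0.0), (3, [1], np.pi / 2, 0.0)], 5, 1e-9), 12))
```

Equivalently, with `dt` expressed in ns and the frequency in GHz, `some_factor` is still `2 * np.pi * dt`, since the product of GHz and ns is dimensionless.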

1ucian0 commented 2 years ago

Hi @snsunx, are you still working on this? I noticed you have been working on a branch that was never PRed. Let me know if I can help you!

snsunx commented 2 years ago

Hi @1ucian0, sorry, I got busy with some other stuff and haven't been working on this. I've unassigned myself, so please feel free to assign this issue to someone else.

landamax commented 1 year ago

Hello everyone, this is my first issue in Qiskit. I think @snsunx did 99% of the work on this issue, and I would like to finish it. Can I be assigned to the issue? I already have some code, built on @snsunx's work, that I'd like reviewed in case it solves the issue. If not, any comments would be appreciated.

nkanazawa1989 commented 1 year ago

Hi @landamax, please feel free to continue. Note that we are trying to refactor (and deprecate) the pulse transform module to promote it to a proper lowering operation, as in a modern compiler. If you need merged waveform data, you can also use the InstructionToSignals converter in Qiskit Dynamics.

snsunx commented 1 year ago

Hi @landamax, some of the work I did can be found on this branch in my forked Qiskit repository. Please let me know if there is anything I can help with.