NVIDIA / cuQuantum

Home for cuQuantum Python & NVIDIA cuQuantum SDK C++ samples
https://docs.nvidia.com/cuda/cuquantum/
BSD 3-Clause "New" or "Revised" License

Mid-circuit measurements cause significant slowdown #138

Closed nathanieltornow closed 1 month ago

nathanieltornow commented 1 month ago

When using mid-circuit measurements, the cusvaer backend appears to become orders of magnitude slower. Is this expected behavior, or is there something I can do to speed up mid-circuit measurements?

To show this, here is a slightly modified version of the Qiskit example.

from qiskit import QuantumCircuit, transpile
from qiskit import Aer
from mpi4py import MPI
import argparse
from time import perf_counter

def create_ghz_circuit(n_qubits, meas):
    circuit = QuantumCircuit(n_qubits, 1)
    circuit.h(0)
    for qubit in range(n_qubits - 1):
        circuit.cx(qubit, qubit + 1)
    for qubit in range(n_qubits - 1):
        circuit.cx(qubit, qubit + 1)

    if meas:
        # Mid-circuit measurement: qubit 0 into classical bit 0.
        circuit.measure(0, 0)

    for qubit in range(n_qubits - 1):
        circuit.cx(qubit, qubit + 1)
    for qubit in range(n_qubits - 1):
        circuit.cx(qubit, qubit + 1)
    return circuit

def run(n_qubits, precision, use_cusvaer, meas):
    simulator = Aer.get_backend('aer_simulator_statevector')
    simulator.set_option('cusvaer_enable', use_cusvaer)
    simulator.set_option('precision', precision)
    circuit = create_ghz_circuit(n_qubits, meas)
    circuit.measure_all()
    circuit = transpile(circuit, simulator)

    start = perf_counter()

    job = simulator.run(circuit)
    result = job.result()

    runtime = perf_counter() - start

    if MPI.COMM_WORLD.Get_rank() == 0:
        counts = result.get_counts()
        print(counts)
        print("time", runtime)

parser = argparse.ArgumentParser(description="Qiskit ghz.")
parser.add_argument('--nbits', type=int, default=20, help='the number of qubits')
parser.add_argument('--meas', default=False, action='store_true', help='add a mid-circuit measurement of qubit 0')
parser.add_argument('--precision', type=str, default='single', choices=['single', 'double'], help='numerical precision')
parser.add_argument('--disable-cusvaer', default=False, action='store_true', help='disable cusvaer')

args = parser.parse_args()

run(args.nbits, args.precision, not args.disable_cusvaer, args.meas)

Without the mid-circuit measurement, it is really fast:

$ python qiskit_ghz.py
{'00000000000000000000 0': 510, '00010001000100010001 0': 514}
time 0.03272801800630987

However, with mid-circuit measurement:

$ python qiskit_ghz.py --meas
{'00010001000100010001 1': 496, '00000000000000000000 0': 528}
time 5.535552535904571

For 25 qubits:

$ python qiskit_ghz.py --nbits 25
{'1000100010001000100010001 0': 517, '0000000000000000000000000 0': 507}
time 0.05724717793054879
$ python qiskit_ghz.py --nbits 25 --meas
{'1000100010001000100010001 1': 528, '0000000000000000000000000 0': 496}
time 74.54424684913829
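A quick back-of-the-envelope check on the timings above, using only the numbers printed in the `--meas` runs. It assumes each printed `counts` dict covers every shot (both sum to 1024, which matches Aer's default shot count). It divides the mid-circuit run time by the final-only run time and by the shot count to get a rough per-shot cost.

```python
# Timings and counts copied from the runs above; nothing here is re-measured.
runs = {
    20: {"counts": {"00010001000100010001 1": 496, "00000000000000000000 0": 528},
         "t_end": 0.03272801800630987, "t_mid": 5.535552535904571},
    25: {"counts": {"1000100010001000100010001 1": 528, "0000000000000000000000000 0": 496},
         "t_end": 0.05724717793054879, "t_mid": 74.54424684913829},
}

summary = {}
for n, r in runs.items():
    shots = sum(r["counts"].values())           # total shots in the mid-circuit run
    slowdown = r["t_mid"] / r["t_end"]          # mid-circuit time vs. final-only time
    per_shot_ms = 1000 * r["t_mid"] / shots     # rough cost of one shot, in milliseconds
    summary[n] = (shots, slowdown, per_shot_ms)
    print(f"{n} qubits: {shots} shots, {slowdown:.0f}x slower, ~{per_shot_ms:.1f} ms per shot")
```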

Specs:

ymagchi commented 1 month ago

Hi @nathanieltornow, thank you for sharing the code and the detailed configuration; it is helpful to us.

This behavior is expected. If measurements are performed only at the end of the simulation, the final state vector is the same for every shot, so it can be computed once and reused for all sampling shots. With mid-circuit measurements, however, the state vector can differ from shot to shot after the measurement, so the simulation is executed once per shot. In this case, the total execution time increases linearly with the number of shots.
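To make the difference concrete, here is a toy sketch. It is a tiny pure-Python state-vector simulator written for illustration only, not how Aer or cusvaer are implemented, and every function name in it is hypothetical. Without a mid-circuit measurement, the final state is computed once and sampled `shots` times. With a measurement on qubit 0 in the middle, the post-measurement state depends on a random outcome, so this naive version re-runs the circuit once per shot.

```python
import math
import random
from collections import Counter

# Toy state-vector simulator (illustrative only, not Aer/cusvaer).
# Bit q of a basis-state index is the value of qubit q (Qiskit's little-endian order).

def apply_h(state, q):
    s = 1 / math.sqrt(2)
    new = state[:]
    for i in range(len(state)):
        if not (i >> q) & 1:
            j = i | (1 << q)
            new[i], new[j] = s * (state[i] + state[j]), s * (state[i] - state[j])
    return new

def apply_cx(state, c, t):
    new = state[:]
    for i in range(len(state)):
        if (i >> c) & 1:  # where the control bit is 1, swap amplitudes across the target bit
            new[i] = state[i ^ (1 << t)]
    return new

def measure(state, q, rng):
    """Projective measurement of qubit q: returns (outcome, collapsed state)."""
    p1 = sum(abs(a) ** 2 for i, a in enumerate(state) if (i >> q) & 1)
    outcome = 1 if rng.random() < p1 else 0
    norm = math.sqrt(p1 if outcome else 1 - p1)
    return outcome, [a / norm if ((i >> q) & 1) == outcome else 0.0
                     for i, a in enumerate(state)]

def sample(state, shots, rng):
    """Draw basis-state indices with Born-rule probabilities |amplitude|^2."""
    return rng.choices(range(len(state)), weights=[abs(a) ** 2 for a in state], k=shots)

def before_meas(n):
    state = [0.0] * (2 ** n)
    state[0] = 1.0
    return apply_cx(apply_h(state, 0), 0, 1)

def after_meas(state):
    return apply_cx(state, 1, 2)

def simulate(n, shots, mid_measure, seed=0):
    """Return (counts, number of full circuit simulations performed)."""
    rng = random.Random(seed)
    if not mid_measure:
        # Same final state for every shot: simulate once, then sample `shots` times.
        final = after_meas(before_meas(n))
        return Counter(sample(final, shots, rng)), 1
    counts = Counter()
    for _ in range(shots):
        # The post-measurement state depends on the random outcome,
        # so this naive loop re-runs the whole circuit for every shot.
        _, state = measure(before_meas(n), 0, rng)
        counts.update(sample(after_meas(state), 1, rng))
    return counts, shots
```

Calling `simulate(3, 1000, mid_measure=False)` performs one simulation, while `simulate(3, 1000, mid_measure=True)` performs 1000, which is the linear-in-shots growth described above.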

The example code has an option, --disable-cusvaer, to run on a different (non-cusvaer) backend; with it, we see a similar performance difference.

nathanieltornow commented 1 month ago

Hi @ymagchi, thank you for your fast reply! This makes total sense.

Is it worth exploring a potential speed-up for this case? This could be achieved, for example, by computing a state vector for each possible mid-circuit measurement outcome and then computing the final counts from these state vectors. I saw this Python code that uses such an approach.
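A minimal sketch of the branching idea, assuming a toy pure-Python simulator and a hypothetical op-list format (`("h", q)`, `("cx", c, t)`, `("measure", q)`). This is not the Python code referenced above, and it is not a cuQuantum or Aer API. Each measurement splits every branch into its two outcomes, weighted by their probabilities; zero-probability branches are dropped. Shots are then drawn from the resulting joint distribution over (mid-circuit outcomes, final basis state) without re-simulating anything.

```python
import math
import random
from collections import Counter

def apply_gate(state, op):
    """Apply an H or CX gate from the hypothetical op format to a state vector."""
    kind, *qs = op
    new = state[:]
    if kind == "h":
        (q,) = qs
        s = 1 / math.sqrt(2)
        for i in range(len(state)):
            if not (i >> q) & 1:
                j = i | (1 << q)
                new[i], new[j] = s * (state[i] + state[j]), s * (state[i] - state[j])
    elif kind == "cx":
        c, t = qs
        for i in range(len(state)):
            if (i >> c) & 1:
                new[i] = state[i ^ (1 << t)]
    else:
        raise ValueError(f"unsupported gate {kind!r}")
    return new

def branch_states(n, ops):
    """Return [(probability, mid-circuit outcomes, final state)], one entry per branch."""
    init = [0.0] * (2 ** n)
    init[0] = 1.0
    branches = [(1.0, (), init)]
    for op in ops:
        if op[0] != "measure":
            branches = [(p, bits, apply_gate(st, op)) for p, bits, st in branches]
            continue
        q = op[1]
        split = []
        for p, bits, st in branches:
            for outcome in (0, 1):
                # Project onto the outcome, then renormalize the surviving amplitudes.
                proj = [a if ((i >> q) & 1) == outcome else 0.0 for i, a in enumerate(st)]
                p_out = sum(abs(a) ** 2 for a in proj)
                if p_out > 1e-12:  # skip impossible branches
                    norm = math.sqrt(p_out)
                    split.append((p * p_out, bits + (outcome,), [a / norm for a in proj]))
        branches = split
    return branches

def sample_counts(branches, shots, seed=0):
    """Sample (mid-circuit outcomes, final basis index) pairs from all branches."""
    rng = random.Random(seed)
    keys, weights = [], []
    for p, bits, st in branches:
        for i, a in enumerate(st):
            w = p * abs(a) ** 2
            if w > 0:
                keys.append((bits, i))
                weights.append(w)
    return Counter(rng.choices(keys, weights=weights, k=shots))
```

For example, `branch_states(3, [("h", 0), ("cx", 0, 1), ("measure", 0), ("cx", 1, 2)])` yields two branches of probability about 0.5 each. `sample_counts(..., 1024)` then draws all 1024 shots from those two stored state vectors.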

This would, however, mean computing $2^m$ state vectors, where $m$ is the number of mid-circuit measurements. This could give a significant speedup if $2^m < \text{shots}$. This condition could be checked before execution: if it holds, we'd use the proposed approach; otherwise, we'd use the current per-shot approach.
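The selection rule can be written as a tiny dispatcher. Both helpers are hypothetical and not part of cusvaer or Aer. $2^m$ is treated as an upper bound on the number of branches, since zero-probability branches could be skipped.

```python
def choose_strategy(num_mid_measurements: int, shots: int) -> str:
    """Hypothetical dispatcher for the 2**m < shots rule.

    Returns "branching" (simulate each outcome branch once) or
    "per-shot" (re-run the circuit for every shot, the current behavior).
    """
    if 2 ** num_mid_measurements < shots:
        return "branching"
    return "per-shot"

def simulations_needed(num_mid_measurements: int, shots: int) -> int:
    """Upper bound on full state-vector simulations under the chosen strategy."""
    return min(2 ** num_mid_measurements, shots)
```

For the single mid-circuit measurement in the example and 1024 shots, `choose_strategy(1, 1024)` returns `"branching"` and `simulations_needed(1, 1024)` is 2 instead of 1024. At `m = 10`, the bound reaches 1024 and the per-shot path is kept.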

Let me know if you think this would be worth including in the code. If yes, I'd be happy to help :)

ymagchi commented 1 month ago

> Is it worth exploring a potential speed-up for this case?

I think it would work for your case above. In general, however, the memory usage and data transfers needed to keep all the state vectors have to be taken into consideration.
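A rough sketch of the memory concern, assuming single and double precision use 8-byte (complex64) and 16-byte (complex128) amplitudes. That is the usual layout, but treat it as an assumption here. It also assumes the naive worst case where all $2^m$ branch state vectors are held at the same time. The helper names are hypothetical.

```python
GiB = 2 ** 30

def statevector_bytes(n_qubits: int, precision: str = "single") -> int:
    """Memory for one state vector: 2**n complex amplitudes."""
    # Assumed sizes: complex64 (single) = 8 bytes, complex128 (double) = 16 bytes.
    bytes_per_amplitude = {"single": 8, "double": 16}[precision]
    return (2 ** n_qubits) * bytes_per_amplitude

def all_branches_bytes(n_qubits: int, num_mid_measurements: int,
                       precision: str = "single") -> int:
    """Worst case: every one of the 2**m branch state vectors kept in memory at once."""
    return (2 ** num_mid_measurements) * statevector_bytes(n_qubits, precision)

for m in (1, 4, 10):
    print(f"n=25, m={m}: {all_branches_bytes(25, m) / GiB:g} GiB (single precision)")
```

At 25 qubits, one single-precision state vector is 256 MiB. Keeping every branch therefore costs 0.5 GiB for `m = 1` but 256 GiB for `m = 10`, far beyond a single GPU.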