PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

Slow performance when compiling (tf and jax-jit) gradients through `mitigate_with_zne` #2801

Open Qottmann opened 2 years ago

Qottmann commented 2 years ago

I noticed that compilation with both jax-jit and tf is very slow when computing gradients through the new mitigate_with_zne transform introduced in https://github.com/PennyLaneAI/pennylane/pull/2757. For an example with n_wires=6 it already takes around 5 minutes to compile the gradient. I did a scaling analysis and found the compilation time to grow roughly as 0.78 * n_wires**3.25 seconds. Just to give an idea, for 16 qubits that would be almost 2 hours of compilation.

[Figure: jax-jit gradient compilation time vs. n_wires with mitigate_with_zne]

This means that, in its current form, jitting is not usable in practice. I wonder why that is and how we can fix it.

As a comparison, the scaling for the same example but without the mitigate_with_zne transform is roughly 0.17 * n_wires**2.79:

[Figure: jax-jit gradient compilation time vs. n_wires without mitigate_with_zne]
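For reference, the exponents quoted above come from timing the first call to the jitted gradient for several values of n_wires and fitting a line in log-log space, roughly like this (the build_grad helper below is just a placeholder for constructing the mitigated QNode and its parameters at a given width):

import time
import numpy as np

def compile_time(build_grad, n_wires):
    # build_grad(n_wires) is assumed to return a freshly jitted gradient
    # function together with a matching parameter vector
    grad, theta = build_grad(n_wires)
    t0 = time.time()
    grad(theta)  # the first call triggers tracing + compilation
    return time.time() - t0

def fit_power_law(wire_counts, times):
    # fit log(t) = a * log(n) + log(c), i.e. t ~ c * n**a
    a, log_c = np.polyfit(np.log(wire_counts), np.log(times), 1)
    return np.exp(log_c), a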

What the transform does is roughly: execute an altered version of the QNode three times (once per scale factor) and extrapolate from the results. So I would expect something like a constant factor-3 increase in time. I think I can narrow the problem down to the folding transform, because when I do the scaling with the mitigation transform but a dummy folding transform

def dummy_folding(tape, scale_factor):
    # no-op folding: return the input tape unchanged
    return tape

we get virtually the same scaling as without the mitigation (so the fitting/extrapolation step is not the problem); a sketch of this setup follows below the plot:

[Figure: jax-jit gradient compilation time vs. n_wires with mitigate_with_zne and the dummy folding]
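For completeness, the dummy-folding run used roughly the following setup, reusing the devices, Hamiltonian, circuit, and parameters from the jax-jit example further down:

# rough sketch of the dummy-folding comparison (dev_noisy, H, scale_factors,
# n_wires and theta are defined as in the jax-jit example below)
@mitigate_with_zne(scale_factors, dummy_folding, richardson_extrapolate)
@qml.qnode(dev_noisy, interface="jax-jit")
def qnode_dummy(theta):
    for i in range(n_wires):
        qml.RY(theta[i], wires=i)
    for i in range(n_wires - 1):
        qml.CNOT(wires=(i, i + 1))
    for i in range(n_wires):
        qml.RY(theta[n_wires + i], wires=i)
    return qml.expval(H)

grad_dummy = jax.jit(jax.grad(qnode_dummy))
grad_dummy(theta)  # compiles in roughly the same time as the unmitigated QNode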

I am no expert in jax-jit, but the common denominator when searching for performance problems seems to be: avoid for loops wherever possible. And unsurprisingly, qml.transforms.fold_global is one big nested for loop. What it does is take the input tape and extend it with resolved identities. I.e., given a tape $U = L_d \ldots L_1$, fold_global creates

$$\text{fold}(U) = U \left(U^\dagger U\right)^n \left(L_d^\dagger L_{d-1}^\dagger \cdots L_s^\dagger\right) \left(L_s \cdots L_d\right)$$

It is doing so by looping through the operations and queuing them one by one. I wonder whether this queuing loop can be avoided.
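To make that concrete, here is a minimal toy example of global folding, assuming fold_global can be applied directly to a QNode as in its docstring (scale_factor=3 corresponds to one full U†U insertion and no partial fold):

import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def toy_circuit(x):
    qml.RY(x, wires=0)
    qml.CNOT(wires=(0, 1))
    return qml.expval(qml.PauliZ(0))

# global folding with scale_factor = 3: the circuit becomes U (U^dagger U)
folded = qml.transforms.fold_global(toy_circuit, 3)
print(qml.draw(folded)(0.5))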

The story is very similar with tf compilation, with similar compilation times. Here is the jax-jit example I was using to create the plots above:

import time

import jax
import jax.numpy as jnp
import pennylane as qml

from pennylane.transforms import mitigate_with_zne, fold_global, richardson_extrapolate

jax.config.update("jax_enable_x64", True)  # float64 parameters are used below

n_wires = 6  # takes around 3-5 minutes to compile the gradient

# Describe noise
noise_gate = qml.DepolarizingChannel
noise_strength = 0.1

# Load devices
dev_ideal = qml.device("default.mixed", wires=n_wires)
dev_noisy = qml.transforms.insert(noise_gate, noise_strength)(dev_ideal)

H = qml.Hamiltonian(coeffs=[1] * (n_wires-1), observables=[qml.PauliX(i) @ qml.PauliX(i+1) for i in range(n_wires-1)])
scale_factors = [1, 2, 3]

@mitigate_with_zne(scale_factors, fold_global, richardson_extrapolate)
@qml.qnode(dev_noisy, interface="jax-jit")
def qnode_mitigated(theta):
    for i in range(n_wires):
        qml.RY(theta[i], wires=i)
    for i in range(n_wires-1):
        qml.CNOT(wires=(i, i+1))
    for i in range(n_wires):
        qml.RY(theta[n_wires + i], wires=i)
    return qml.expval(H)

theta = jnp.arange(n_wires*2, dtype="float64")

t0 = time.time()
grad = jax.jit(jax.grad(qnode_mitigated))
grad(theta)  # the first call triggers tracing and compilation
dt = time.time() - t0
print(f"Compiling gradient time: {dt} s")

%timeit grad(theta)  # subsequent calls reuse the compiled gradient
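As a side note on the measurement itself: JAX dispatches asynchronously, so to separate compilation from execution a bit more cleanly one can block on the result. A rough sketch (nothing here is specific to mitigate_with_zne):

t0 = time.time()
grad = jax.jit(jax.grad(qnode_mitigated))
grad(theta).block_until_ready()  # tracing + compilation + first execution
print(f"Compile + first run: {time.time() - t0} s")

t0 = time.time()
grad(theta).block_until_ready()  # compiled execution only
print(f"Compiled run: {time.time() - t0} s")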
Qottmann commented 2 years ago

I traced the computation for the jit example (around 360 s in total):

[Figure: trace of the jax-jit example]

Qottmann commented 2 years ago

And here is the trace for a tf example:

[Figure: trace of the tf example]

import time
import pennylane as qml
import pennylane.numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf

from matplotlib import pyplot as plt

from pennylane.transforms import mitigate_with_zne, fold_global, richardson_extrapolate

# tf compiling
n_wires = 6

# Describe noise
noise_gate = qml.DepolarizingChannel
noise_strength = 0.1

# Load devices
dev_ideal = qml.device("default.mixed", wires=n_wires)
dev_noisy = qml.transforms.insert(noise_gate, noise_strength)(dev_ideal)

H = qml.Hamiltonian(coeffs=[1] * (n_wires-1), observables=[qml.PauliX(i) @ qml.PauliX(i+1) for i in range(n_wires-1)])
scale_factors = [1, 2, 3]

#@mitigate_with_zne(scale_factors, fold_global, richardson_extrapolate)
@tf.function
@mitigate_with_zne(scale_factors, fold_global, richardson_extrapolate)
@qml.qnode(dev_noisy, interface="tf")
def qnode_mitigated(theta):
    qml.RY(theta[0], wires=0)
    qml.RY(theta[1], wires=1)
    qml.RY(theta[2], wires=2)
    qml.RY(theta[3], wires=3)
    qml.CNOT(wires=(0, 1))
    qml.CNOT(wires=(1, 2))
    qml.CNOT(wires=(2, 3))
    qml.RY(theta[4], wires=0)
    qml.RY(theta[5], wires=1)
    qml.RY(theta[6], wires=2)
    qml.RY(theta[7], wires=3)
    return qml.expval(H)

theta = tf.Variable(np.arange(n_wires*2), dtype="float64")

t0 = time.time()
with tf.GradientTape() as tape:
    res = qnode_mitigated(theta)
dt = time.time() - t0
print(f"Compiling gradient time: {dt} s")

t0 = time.time()
grad = tape.gradient(res, theta)
dt = time.time() - t0
print(f"Gradient execution time: {dt} s")
Compiling gradient time: 198.87679767608643 s
Gradient execution time: 9.8185133934021 s

I am actually not sure anything is compiling here because I get the warning:

WARNING:tensorflow:AutoGraph could not transform <function _gcd_import at 0x7f0e4d063430> and will run it as-is.
Cause: Unable to locate the source code of <function _gcd_import at 0x7f0e4d063430>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
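Following the hint in the warning, one thing to try (untested, and with a shortened stand-in circuit) would be to mark the mitigated QNode as not-to-be-converted by AutoGraph before tracing it with tf.function:

# untested sketch: suppress AutoGraph conversion as suggested by the warning;
# the circuit below is only a shortened stand-in for the one above
@tf.function
@tf.autograph.experimental.do_not_convert
@mitigate_with_zne(scale_factors, fold_global, richardson_extrapolate)
@qml.qnode(dev_noisy, interface="tf")
def qnode_mitigated_noag(theta):
    qml.RY(theta[0], wires=0)
    qml.RY(theta[1], wires=1)
    qml.CNOT(wires=(0, 1))
    return qml.expval(qml.PauliZ(0))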