Qottmann opened 2 years ago
I traced the computation for the jit example (360 s):

*(profiling graph `graph_jit-test` not reproduced)*

And here for a tf example:

*(profiling graph `graph_tf-test` not reproduced)*
```python
import time

import pennylane as qml
import pennylane.numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf
from matplotlib import pyplot as plt
from pennylane.transforms import mitigate_with_zne, fold_global, richardson_extrapolate

# tf compiling
n_wires = 6

# Describe noise
noise_gate = qml.DepolarizingChannel
noise_strength = 0.1

# Load devices
dev_ideal = qml.device("default.mixed", wires=n_wires)
dev_noisy = qml.transforms.insert(noise_gate, noise_strength)(dev_ideal)

H = qml.Hamiltonian(
    coeffs=[1] * (n_wires - 1),
    observables=[qml.PauliX(i) @ qml.PauliX(i + 1) for i in range(n_wires - 1)],
)

scale_factors = [1, 2, 3]

# @mitigate_with_zne(scale_factors, fold_global, richardson_extrapolate)
@tf.function
@mitigate_with_zne(scale_factors, fold_global, richardson_extrapolate)
@qml.qnode(dev_noisy, interface="tf")
def qnode_mitigated(theta):
    qml.RY(theta[0], wires=0)
    qml.RY(theta[1], wires=1)
    qml.RY(theta[2], wires=2)
    qml.RY(theta[3], wires=3)
    qml.CNOT(wires=(0, 1))
    qml.CNOT(wires=(1, 2))
    qml.CNOT(wires=(2, 3))
    qml.RY(theta[4], wires=0)
    qml.RY(theta[5], wires=1)
    qml.RY(theta[6], wires=2)
    qml.RY(theta[7], wires=3)
    return qml.expval(H)

theta = tf.Variable(np.arange(n_wires * 2), dtype="float64")

t0 = time.time()
with tf.GradientTape() as tape:
    res = qnode_mitigated(theta)
dt = time.time() - t0
print(f"Compiling gradient time: {dt} s")

t0 = time.time()
grad = tape.gradient(res, theta)
dt = time.time() - t0
print(f"Gradient execution time: {dt} s")
```

```
Compiling gradient time: 198.87679767608643 s
Gradient execution time: 9.8185133934021 s
```
I am actually not sure anything is compiling here because I get the warning:
```
WARNING:tensorflow:AutoGraph could not transform <function _gcd_import at 0x7f0e4d063430> and will run it as-is.
Cause: Unable to locate the source code of <function _gcd_import at 0x7f0e4d063430>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
```
I noticed that `jax-jit` and `tf` compilation is very slow when computing gradients through the new `mitigate_with_zne` transform introduced in https://github.com/PennyLaneAI/pennylane/pull/2757. For an example with `n_wires=6` it already takes around 5 minutes to compile the gradient. I did some scaling and found the compilation time to be roughly `0.78 * n_wires**3.25`. Just to give an idea, for 16 qubits that would be almost 2 hours of compilation. This means that, in its current form and in practice, jitting is not usable. I wonder why that is and how we can fix it.
As a comparison, the scaling for the same example but without the `mitigate_with_zne` transform is `0.17 * n_wires**2.79`.
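For reference, exponents like the ones quoted above can be recovered from timing data with a simple log-log linear fit. A minimal sketch (the sample times below are generated from the reported fit `0.78 * n_wires**3.25` purely for illustration, not measured):

```python
import numpy as np

# Recover a power law t = a * n**b from timing samples via a linear
# fit in log-log space. The samples are synthesized from the fit
# reported above, so the fit should recover a = 0.78, b = 3.25 exactly.
n_wires = np.array([4, 6, 8, 10, 12])
times = 0.78 * n_wires**3.25

b, log_a = np.polyfit(np.log(n_wires), np.log(times), 1)
print(f"t = {np.exp(log_a):.2f} * n_wires**{b:.2f}")
```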
What the transform does is roughly: execute an altered version of the QNode 3 times, and extrapolate from the results. So I would expect something like a constant factor-3 time increase. I think I can narrow the problem down to the folding transform, because when I do the scaling with the mitigation transform but a dummy folding transform, I get virtually the same result as without it (so the fitting is not the problem):

*(scaling plot not reproduced)*
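For context, a "dummy" folding transform only needs to match the call signature that `mitigate_with_zne` expects of its folding argument (a circuit and a scale factor). A minimal sketch of such a no-op stand-in (`fold_dummy` is my own naming, not part of PennyLane):

```python
# A no-op stand-in for qml.transforms.fold_global: it accepts the same
# (circuit, scale_factor) arguments but skips all op-by-op queuing and
# returns the circuit unchanged. Swapping it in for fold_global
# separates the cost of the folding loop from the cost of the
# extrapolation fit.
def fold_dummy(circuit, scale_factor):
    return circuit
```

Passing `fold_dummy` instead of `fold_global` to the `mitigate_with_zne` decorator is what produced the "virtually the same as without mitigation" scaling above.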
I am no expert in `jax-jit`, but the common denominator when hunting performance problems seems to be: avoid for loops wherever possible. And unsurprisingly, `qml.transforms.fold_global` is one big nested for loop. What it does is take the input tape and extend it by resolved identities. I.e., given a tape $U = L_d \cdots L_1$, `fold_global` will create

$$\text{fold}(U) = U (U^\dagger U)^n \left(L_d^\dagger L_{d-1}^\dagger \cdots L_s^\dagger\right) \left(L_s \cdots L_d\right)$$

It is doing so by looping through the elements and queuing them. I wonder,
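Stripped of all PennyLane machinery, that queuing loop behaves roughly like the following toy sketch over symbolic layer labels (my own simplification, not the actual implementation: adjoints are marked with a `†` suffix, and the scale factor is decomposed into whole `(U† U)` repetitions plus a partial fold of the last layers):

```python
def fold_global_sketch(layers, scale_factor):
    # Decompose the scale factor: n whole (U† U) repetitions, plus a
    # leftover fraction folded over the last s layers.
    d = len(layers)
    n, frac = divmod((scale_factor - 1) / 2, 1)
    n = int(n)
    s = int(round(frac * d))

    adj = [l + "†" for l in layers]  # symbolic adjoint of each layer

    folded = list(layers)            # U
    for _ in range(n):               # (U† U)^n, queued op by op
        folded += adj[::-1] + layers
    if s:                            # partial fold of the last s layers
        folded += adj[::-1][:s] + layers[-s:]
    return folded

# scale factor 3 triples the depth: U (U† U)
print(fold_global_sketch(["L1", "L2"], 3))
```

The point is that every queued element is a Python-level loop iteration, which is exactly what `jax-jit` and AutoGraph tracing handle poorly.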
The story is very similar for `tf` compilation, with similar compilation times. Here is the `jax-jit` example I was using to create the plots above: