Open · dwierichs opened 1 year ago
Hey @dwierichs, is this still an open issue?
If so, @albi3ro, will this be fixed by swapping where gradient transforms are done in relation to transforms/decompositions (your proposal for Q4)?
I looked at example code on 4 qubits (for which `SpecialUnitary` has 255 parameters), code below.

- `jax.grad`? I'm not entirely sure, but the compile time/first run gets sped up (12.4s -> ~12s) by manually skipping these ops (in `apply_operation`), whereas the execution time did not really change (it is at about 4.3s).
- With torch, the gradient computation is 3.2% faster if we manually skip identity ops (64.34s -> 62.25s).
- With tf, the gradient takes ~5.1s.

The code (run in Jupyter to use the magic `%prun`):
```python
import pennylane as qml

interface = "jax"  # choose one of "jax", "torch", "tf"

if interface == "jax":
    import jax

    jax.config.update("jax_enable_x64", True)
if interface == "torch":
    import torch
if interface == "tf":
    import tensorflow as tf

N = 4
wires = list(range(N))

@qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift")
def node(x):
    qml.SpecialUnitary(x, wires)
    return qml.expval(qml.Z(0))

if interface == "jax":
    key = jax.random.PRNGKey(824)
    x = jax.random.uniform(key, (4**N - 1,))
    %prun jax.grad(node)(x)
if interface == "torch":
    x = torch.rand(4**N - 1, requires_grad=True)
    out = node(x)
    %prun out.backward(retain_graph=True)
if interface == "tf":
    x = tf.Variable(tf.random.uniform((4**N - 1,)))
    with tf.GradientTape() as tape:
        out = node(x)
    %prun tape.gradient(out, x)
```
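For reference, a minimal sketch (my illustration, not the patch used for the numbers above, which skipped the ops inside `apply_operation`) of what skipping these ops could look like at the tape level: a transform that filters out `PauliRot` gates whose concrete angle is zero. The transform name `drop_zero_angle_pauli_rots` is made up.

```python
import pennylane as qml
from pennylane.tape import QuantumScript


@qml.transform
def drop_zero_angle_pauli_rots(tape):
    """Hypothetical helper: remove PauliRot gates that act as the identity."""
    ops = [
        op
        for op in tape.operations
        if not (
            isinstance(op, qml.PauliRot)
            and not qml.math.is_abstract(op.data[0])  # keep traced parameters
            and qml.math.allclose(op.data[0], 0.0)
        )
    ]
    new_tape = QuantumScript(ops, tape.measurements, shots=tape.shots)
    return [new_tape], lambda results: results[0]
```

Applying it as `node = drop_zero_angle_pauli_rots(node)` would only help if it runs after the gradient transform produces the zero-angle gates; that ordering is exactly what changed with the new device, as described below.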
Feature details
In #4585, the decomposition of `TmpPauliRot`, a helper object of `SpecialUnitary`, was changed in order to allow the new `DefaultQubit` device to differentiate `SpecialUnitary`. The decomposition-based differentiation pipeline assumed that the device's decomposition step before execution would happen after the trainable parameters are determined and the corresponding gradient transforms are called. Unfortunately, this changes with the new device, so that now the zero-angle instances of `TmpPauliRot` are decomposed into zero-angle instances of `PauliRot`, which in turn make it into the execution pipeline. This makes the simulator execute a lot of identity operations, which causes some overhead (benchmark to be provided).

This affects differentiation of `SpecialUnitary` only, and only when using a differentiation method that uses the decomposition, like `param_shift` (as happens when using shot-based simulation).
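A quick check (a sketch using the `pennylane.devices.qubit` helpers) confirming that each zero-angle `PauliRot` is a plain identity, so every one that reaches the simulator is pure overhead:

```python
import numpy as np
import pennylane as qml
from pennylane.devices.qubit import apply_operation, create_initial_state

# A zero-angle PauliRot leaves any state unchanged, yet the simulator still
# builds and applies it like any other gate.
state = create_initial_state(range(2))
new_state = apply_operation(qml.PauliRot(0.0, "XY", wires=[0, 1]), state)
assert np.allclose(new_state, state)
```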
Implementation

I do not see a straightforward solution at this point. Possible implementations rely on the composability of gradient transforms envisioned for the mid-term future, or on a generalization of the `generator` property of operations (see the sketch below).

Other ideas to implement the (non-backprop) differentiation of `SpecialUnitary` are very much welcome.
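To illustrate why a generalization would be needed (my sketch, not part of the proposal above): the existing `generator` property covers gates of the single-parameter form `exp(i x G)` with a fixed operator `G`, whereas `SpecialUnitary` on `n` wires carries `4**n - 1` parameters and has no such single generator.

```python
import pennylane as qml
from pennylane.operation import GeneratorUndefinedError

# Single-parameter gates expose a fixed generator...
print(qml.RX(0.5, wires=0).generator())  # -0.5 * X(0)

# ...but SpecialUnitary does not, so generator-based recipes cannot be used.
try:
    qml.SpecialUnitary([0.1, 0.2, 0.3], wires=0).generator()
except GeneratorUndefinedError as exc:
    print("No generator:", exc)
```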
How important would you say this feature is?

2: Somewhat important. Needed this quarter.
Additional information
No response