PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Apache License 2.0

[BUG] Performance regression of decomposition-based differentiation of `SpecialUnitary` #4635

Open dwierichs opened 1 year ago

dwierichs commented 1 year ago

Feature details

In #4585, the decomposition of TmpPauliRot, a helper object of SpecialUnitary, was changed to allow the new DefaultQubit device to differentiate SpecialUnitary. The decomposition-based differentiation pipeline assumed that the device's decomposition step before execution would happen after the trainable parameters are determined and the corresponding gradient transforms are called. Unfortunately, this order changes with the new device, so the zero-angle instances of TmpPauliRot are now decomposed into zero-angle instances of PauliRot, which in turn make it into the execution pipeline. As a result, the simulator executes many identity operations, which causes some overhead (benchmark to be provided).

This affects differentiation of SpecialUnitary only, and only when using a differentiation method that uses the decomposition, like param_shift (as happens when using shot-based simulation).
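To get a rough feeling for the scale of the overhead described above, here is a back-of-the-envelope count. It is only a sketch under assumptions that are not stated in this issue: one zero-angle helper rotation per parameter, a two-term shift rule per parameter, and every zero-angle rotation in each shifted circuit actually being executed. The function names are hypothetical and not part of PennyLane.

```python
def num_su_params(num_wires: int) -> int:
    """Number of parameters of a special unitary on `num_wires` qubits."""
    return 4**num_wires - 1


def estimate_identity_ops(num_wires: int, shifts_per_param: int = 2) -> int:
    """Estimate how many zero-angle (identity) rotations get executed.

    Assumptions (hypothetical, for illustration only):
    - one helper rotation per parameter,
    - `shifts_per_param` shifted circuits per parameter,
    - in each shifted circuit, all helper rotations except the shifted
      one still have angle zero and are executed anyway.
    """
    d = num_su_params(num_wires)
    num_shifted_circuits = shifts_per_param * d
    identities_per_circuit = d - 1
    return num_shifted_circuits * identities_per_circuit


if __name__ == "__main__":
    for n in (1, 2, 3, 4):
        print(n, num_su_params(n), estimate_identity_ops(n))
```

Under these assumptions the count grows quadratically in the number of parameters, which is why the effect becomes noticeable already for a few qubits.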


I do not see a straightforward solution at this point. Possible implementations rely on composability of gradient transforms envisioned for the mid-term future, or on a generalization of the generator property of operations.

Other ideas to implement the (non-backprop) differentiation of SpecialUnitary are very much welcome.

How important would you say this feature is?

2: Somewhat important. Needed this quarter.

Additional information

No response

josh146 commented 3 months ago

Hey @dwierichs, is this still an open issue?

josh146 commented 3 months ago

If not, @albi3ro will this be fixed by swapping where gradient transforms are done in relation to transforms/decompositions (your proposal for Q4?)

dwierichs commented 3 months ago

I looked at example code on 4 qubits (for which SpecialUnitary has 255 parameters); the code is below.
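The parameter count quoted above can be checked by enumerating the non-identity Pauli words on N qubits, of which there are 4**N - 1. This is a plain-Python sketch; `pauli_words` is a hypothetical helper, and its ordering is not claimed to match what PennyLane uses internally.

```python
from itertools import product


def pauli_words(num_wires):
    """Return all non-identity Pauli words on `num_wires` qubits as strings."""
    identity = "I" * num_wires
    words = ("".join(letters) for letters in product("IXYZ", repeat=num_wires))
    # Drop the all-identity word, which only contributes a global phase.
    return [word for word in words if word != identity]


if __name__ == "__main__":
    print(len(pauli_words(4)))
```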

The code (run in Jupyter to use the %prun magic):

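import pennylane as qml

# Assumption: `interface` is not defined in the original snippet; set it to "jax", "torch" or "tf".
interface = "jax"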

if interface=="jax":
    import jax
    jax.config.update("jax_enable_x64", True)
if interface=="torch":
    import torch
if interface=="tf":
    import tensorflow as tf

N = 4
wires = list(range(N))

@qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift")
def node(x):
    qml.SpecialUnitary(x, wires)
    return qml.expval(qml.Z(0))

if interface=="jax":
    key = jax.random.PRNGKey(824)
    x = jax.random.uniform(key, (4**N - 1,))
    %prun jax.grad(node)(x)
if interface=="torch":
    x = torch.rand(4**N - 1, requires_grad=True)
    out = node(x)
    %prun out.backward(retain_graph=True)
if interface=="tf":
    x = tf.Variable(tf.random.uniform((4**N - 1,)))  # shape must be a 1-D sequence, not a bare int
    with tf.GradientTape() as tape:
        out = node(x)
    %prun tape.gradient(out, x)
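For anyone who wants to reproduce the profile outside Jupyter, where `%prun` is not available, something like the following standard-library wrapper might serve as a stand-in. `profile_call` is a hypothetical helper, not part of PennyLane or IPython, and its output format will differ somewhat from what `%prun` prints.

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, top=10, **kwargs):
    """Run `fn(*args, **kwargs)` under cProfile.

    Returns the function's result and a text report of the `top`
    entries sorted by cumulative time.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(fn, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()


if __name__ == "__main__":
    # Stand-in workload; with the snippet above one could instead call
    # profile_call(jax.grad(node), x) for the JAX case.
    def workload(n):
        return sum(i * i for i in range(n))

    value, report = profile_call(workload, 10_000)
    print(report)
```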