PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

Support hardware differentiation for broadcasted `SpecialUnitary` #3721

Open dwierichs opened 1 year ago

dwierichs commented 1 year ago

Feature details

Currently, SpecialUnitary supports broadcasting and is auto-differentiable when it is broadcasted.

In addition, #3674 adds differentiation with gradient transforms to SpecialUnitary.

However, the two features cannot be combined yet, i.e. gradient transforms cannot be applied to a broadcasted SpecialUnitary.

Implementation

To support this feature, the decomposition method that makes SpecialUnitary transform-differentiable needs to be made compatible with broadcasting. The main work is likely in get_one_parameter_generators and get_one_parameter_coeffs. Overall, it should mostly be about making the right case separations (batched vs. unbatched parameters) and choosing tensor manipulation axes correctly.
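To illustrate what "choosing tensor manipulation axes correctly" can look like, here is a minimal NumPy sketch for the single-qubit case. It assumes the documented parametrization U(θ) = exp(i Σ_m θ_m P_m) with the Pauli words ordered X, Y, Z. The names PAULIS and su2_matrix are hypothetical, and this is not the code of get_one_parameter_generators. The point is only that contracting over the last axis and carrying leading axes with an ellipsis handles the batched and unbatched cases without a separate branch.

```python
import numpy as np

# Single-qubit Pauli basis in the order X, Y, Z (assumed ordering).
PAULIS = np.array(
    [
        [[0, 1], [1, 0]],
        [[0, -1j], [1j, 0]],
        [[1, 0], [0, -1]],
    ],
    dtype=complex,
)


def su2_matrix(theta):
    """Return exp(i * sum_m theta[..., m] * P_m) for one qubit.

    theta has shape (3,) or (batch, 3); the last axis always indexes
    the Pauli words, any leading axis is treated as a batch axis.
    """
    theta = np.asarray(theta, dtype=float)
    # "..." carries an optional leading batch axis through unchanged.
    H = np.einsum("...m,mij->...ij", theta, PAULIS)
    # Norm per parameter set, reshaped to broadcast against (..., 2, 2).
    norm = np.linalg.norm(theta, axis=-1)[..., None, None]
    # sin(norm) / norm, well defined at norm = 0 (np.sinc is normalized by pi).
    sinc = np.sinc(norm / np.pi)
    # Closed form of the exponential of a traceless 2x2 Hermitian matrix.
    return np.cos(norm) * np.eye(2) + 1j * sinc * H


single = su2_matrix([0.5, 0.2, 1.2])          # shape (2, 2)
batch = su2_matrix([[0.3, 0.2, 0.1],
                    [0.5, 1.2, -0.6]])        # shape (2, 2, 2)
```

The closed form only exists for one qubit; for more qubits a matrix exponential routine would be needed, but the same last-axis convention applies.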

How important would you say this feature is?

1: Not important. Would be nice to have.

Additional information

No response

dwierichs commented 1 year ago

Comment upon revisiting this: Hardware-compatible differentiation already works with broadcasting if the broadcasting is done via jax.vmap, provided that the isclose check in TmpPauliRot is changed to allclose:

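import jax
import jax.numpy as jnp
import pennylane as qml

@qml.qnode(qml.device("default.qubit", wires=1), interface="jax", diff_method="parameter-shift")
def node(x):
    qml.SpecialUnitary(x, [0])
    return qml.expval(qml.PauliZ(0))

# Unbatched parameters: expectation value and its gradient
x = jnp.array([0.5, 0.2, 1.2])
print(node(x))
print(jax.grad(node)(x))

# Batch of six parameter sets, mapped over the leading axis with jax.vmap
y = jnp.array([[0.3, 0.2, 0.1], [0.5, 1.2, -0.6]]*3)
vmap_node = jax.vmap(node)
print(vmap_node(y))
print(jax.vmap(jax.jacobian(node))(y))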
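
The general difference between an isclose-style and an allclose-style check can be seen in plain NumPy, independent of PennyLane internals. Whether this is exactly the failure mode inside TmpPauliRot is an assumption, and the helper names below are hypothetical. An elementwise check returns one boolean per entry and cannot be used as a single condition once the parameter carries a batch axis, whereas a reduced check always yields one boolean.

```python
import numpy as np


def is_zero_elementwise(theta):
    # np.isclose compares entry by entry; converting the result to a
    # single bool only works if theta holds exactly one element.
    return bool(np.isclose(theta, 0.0))


def is_zero_reduced(theta):
    # np.allclose reduces over all entries and returns one bool,
    # regardless of the shape of theta.
    return bool(np.allclose(theta, 0.0))


scalar = 0.0
batched = np.zeros(3)

print(is_zero_elementwise(scalar))   # fine for a scalar parameter
print(is_zero_reduced(batched))      # fine for a batched parameter

try:
    is_zero_elementwise(batched)     # ambiguous truth value for a batch
except ValueError as err:
    print("elementwise check failed:", err)
```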