dwierichs opened 1 year ago
Comment upon revisiting this: hardware-compatible differentiation already works with broadcasting if it is done via `jax.vmap` and if the `isclose` check in `TmpPauliRot` is changed to `allclose`:
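```python
import jax
import jax.numpy as jnp
import pennylane as qml
from pennylane import Z

@qml.qnode(qml.device("default.qubit", wires=1), interface="jax", diff_method="parameter-shift")
def node(x):
    qml.SpecialUnitary(x, [0])
    return qml.expval(Z(0))

# One parameter set: a single-wire SpecialUnitary takes 4**1 - 1 = 3 parameters
x = jnp.array([0.5, 0.2, 1.2])
print(node(x))
print(jax.grad(node)(x))

# Batch of 6 parameter sets (the two rows repeated three times), mapped with jax.vmap
y = jnp.array([[0.3, 0.2, 0.1], [0.5, 1.2, -0.6]] * 3)
vmap_node = jax.vmap(node)
print(vmap_node(y))
print(jax.vmap(jax.jacobian(node))(y))
```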
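To illustrate why swapping `isclose` for `allclose` matters once parameters are batched, here is a minimal NumPy sketch. It assumes that `TmpPauliRot` evaluates the closeness check inside a Python `if` (for example to skip a rotation by a zero angle); that is an assumption about its internals, and both function names below are hypothetical.

```python
import numpy as np

def skip_if_zero_isclose(theta):
    # Hypothetical scalar-style check: np.isclose returns one bool per batch entry,
    # so collapsing it into a single Python bool fails for a batch of angles.
    return bool(np.isclose(theta, 0.0))

def skip_if_zero_allclose(theta):
    # Broadcasting-safe variant: np.allclose reduces over the whole batch and
    # only returns True if every angle is close to zero.
    return bool(np.allclose(theta, 0.0))

single = 0.0
batch = np.array([0.0, 0.3, 1.2])

print(skip_if_zero_isclose(single))   # True
print(skip_if_zero_allclose(batch))   # False
try:
    skip_if_zero_isclose(batch)
except ValueError as err:
    print("isclose on a batch:", err)
```

Note that the `allclose` version changes the semantics slightly: a batch skips the shortcut unless all of its angles are zero, which is the safe choice when one code path must serve the whole batch.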
Feature details
Currently, `SpecialUnitary` supports broadcasting and is auto-differentiable when it is broadcasted. In addition, #3674 adds differentiation via gradient transforms to `SpecialUnitary`. However, the two features cannot be combined yet, i.e. gradient transforms cannot be applied to a broadcasted `SpecialUnitary`.
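For context on what "broadcasted" means here: a `SpecialUnitary` on n wires takes 4**n - 1 parameters, and broadcasting means passing a stacked array with an extra leading batch axis instead of a single vector. The helper below is hypothetical (PennyLane operators expose comparable information through their `batch_size` attribute); it only sketches the case separation a broadcast-aware code path has to make.

```python
import numpy as np

def special_unitary_batch_size(theta, num_wires):
    """Hypothetical helper: None for a single parameter vector, else the batch size."""
    theta = np.asarray(theta)
    num_params = 4**num_wires - 1  # dimension of su(2**num_wires)
    if theta.ndim not in (1, 2):
        raise ValueError("expected shape (num_params,) or (batch_size, num_params)")
    if theta.shape[-1] != num_params:
        raise ValueError(f"expected {num_params} parameters per set, got {theta.shape[-1]}")
    # Unbroadcasted: one vector of parameters
    if theta.ndim == 1:
        return None
    # Broadcasted: the leading axis indexes independent parameter sets
    return theta.shape[0]

print(special_unitary_batch_size([0.5, 0.2, 1.2], num_wires=1))    # None
print(special_unitary_batch_size(np.zeros((6, 3)), num_wires=1))   # 6
print(special_unitary_batch_size(np.zeros((4, 15)), num_wires=2))  # 4
```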
Implementation
In order to support this feature, the decomposition method that `SpecialUnitary` uses to become transform-differentiable needs to be made compatible with broadcasting. Most of the work is likely in `get_one_parameter_generators` and `get_one_parameter_coeffs`; overall, it should mostly be a matter of making the right case separations and choosing the axes for tensor manipulations correctly.
How important would you say this feature is?
1: Not important. Would be nice to have.
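To make the Implementation notes more concrete, here is a single-qubit JAX sketch of the axis bookkeeping involved. It assumes that `get_one_parameter_generators` computes effective generators Ω_j = U(θ)† ∂U(θ)/∂θ_j and that `get_one_parameter_coeffs` expands them in the Pauli basis; all function names below are hypothetical, this is not PennyLane's implementation, and it uses the closed form of exp(i θ·σ) (valid for θ ≠ 0) instead of a general matrix exponential.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Pauli words X, Y, Z: the 4**1 - 1 = 3 basis elements for one qubit
PAULIS = jnp.array(
    [[[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]],
    dtype=jnp.complex128,
)


def su2_matrix(theta):
    """U(theta) = exp(i * sum_j theta_j P_j) in closed form; requires |theta| > 0."""
    r = jnp.linalg.norm(theta)
    gen = jnp.tensordot(theta, PAULIS, axes=1)  # sum_j theta_j P_j, shape (2, 2)
    return jnp.cos(r) * jnp.eye(2) + 1j * jnp.sin(r) / r * gen


def generators_single(theta):
    """Effective generators U^dagger dU/dtheta_j for one parameter vector, shape (3, 2, 2)."""
    U = su2_matrix(theta)
    dU = jnp.moveaxis(jax.jacfwd(su2_matrix)(theta), -1, 0)  # (2, 2, 3) -> (3, 2, 2)
    return jnp.matmul(U.conj().T, dU)


def generators(theta):
    """Case separation: shape (3,) -> (3, 2, 2); shape (batch, 3) -> (batch, 3, 2, 2)."""
    if theta.ndim == 1:
        return generators_single(theta)
    return jax.vmap(generators_single)(theta)


def generators_full_jacobian(theta):
    """Batched variant without vmap: the Jacobian of the whole batch has shape
    (B, 2, 2, B, 3) with zero off-diagonal batch blocks, so keep only the diagonal."""
    U = jax.vmap(su2_matrix)(theta)                  # (B, 2, 2)
    dU = jax.jacfwd(jax.vmap(su2_matrix))(theta)     # (B, 2, 2, B, 3)
    dU = jnp.diagonal(dU, axis1=0, axis2=3)          # (2, 2, 3, B)
    dU = jnp.moveaxis(dU, (-1, -2), (0, 1))          # (B, 3, 2, 2)
    U_dag = jnp.swapaxes(U.conj(), -1, -2)[:, None]  # (B, 1, 2, 2)
    return jnp.matmul(U_dag, dU)


def pauli_coeffs(omegas):
    """Real w_jk with Omega_j = i * sum_k w_jk P_k; leading batch axes are kept."""
    # tr(P_k Omega_j) = 2i * w_jk because tr(P_k P_l) = 2 * delta_kl
    traces = jnp.einsum("kab,...jba->...jk", PAULIS, omegas)
    return jnp.real(traces / 2j)


theta = jnp.array([[0.3, 0.2, 0.1], [0.5, 1.2, -0.6]])
omegas = generators(theta)
print(omegas.shape)                                           # (2, 3, 2, 2)
print(jnp.allclose(omegas, generators_full_jacobian(theta)))  # True
print(pauli_coeffs(omegas).shape)                             # (2, 3, 3)
```

The two batched variants agree; the difference is where the case separation happens. Mapping over the batch keeps every intermediate tensor per-sample, while differentiating the whole batch at once produces cross-batch blocks that must be discarded by picking the right pair of axes.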
Additional information
No response