Closed: dwierichs closed this issue 1 year ago.
The argnum keyword argument is used in several transforms throughout PennyLane, such as
- qml.gradients.param_shift in pennylane/gradients/parameter_shift.py
- qml.gradients.finite_diff in pennylane/gradients/finite_difference.py
- qml.transforms.classical_jacobian
Note: In this issue, the metric_tensor function that needs modification is a batch_transform, which behaves similarly to gradient_transforms. Therefore, it is sufficient and consistent to consider only tape arguments when implementing argnum. This is explained in detail below.
Note: Support for the new return type system is not required in the first implementation; it can be split out into a second pull request.
Note: This issue can be broken down further into implementing argnum separately for the internal functions _metric_tensor_cov_matrix and _metric_tensor_hadamard.
The meaning of argnum can be slightly confusing. Note that the first two transforms in the list above are gradient_transforms: they are written as functions that act on QuantumTapes and are automatically turned into functions that act on QNodes by the gradient_transform decorator. Because of this automatic conversion, the argnum argument of these transforms refers to tape arguments, not to QNode arguments, even if the transform is applied to a QNode. An example:
import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=2)
x = np.array(0.4, requires_grad=True)
y = np.array(0.1, requires_grad=True)

@qml.qnode(dev)
def circuit(x, y):
    qml.RX(y, 0)
    qml.RY(x, 1)
    return qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))
This QNode returns two expectation values, the first depending on y and the second on x (this "twist" is intentional).
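To make the twist concrete, here is a hypothetical, minimal model (not PennyLane's actual tape class) of how a QNode call is recorded: operations are stored in the order they are queued, so the tape's parameter list follows the gate order rather than the order of the QNode's arguments. MockTape and its methods are illustrative names only.

```python
from dataclasses import dataclass, field


@dataclass
class MockTape:
    """Hypothetical stand-in for a quantum tape: a list of (gate, parameter, wire)."""
    ops: list = field(default_factory=list)

    def queue(self, gate, param, wire):
        # Operations are recorded in the order the circuit function queues them.
        self.ops.append((gate, param, wire))

    def get_parameters(self):
        # Tape parameters follow queue order, not QNode-argument order.
        return [param for _, param, _ in self.ops]


def circuit(tape, x, y):
    # Mirrors the QNode above: RX(y, 0) is queued before RY(x, 1).
    tape.queue("RX", y, 0)
    tape.queue("RY", x, 1)


tape = MockTape()
circuit(tape, x=0.4, y=0.1)
params = tape.get_parameters()
print(params)  # [0.1, 0.4] -> tape parameter 0 is y, tape parameter 1 is x
```

A tape-level argnum=0 therefore selects y in this circuit, even though y is the second QNode argument.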
Now, if we compute the parameter-shift gradient with argnum=0, we get non-zero entries in the second part of the gradient, which belongs to y:
>>> qml.gradients.param_shift(circuit, argnum=0)(x, y)
(tensor([0., 0.], requires_grad=True),
tensor([-9.98334166e-02, -5.55111512e-17], requires_grad=True))
This is because argnum=0 refers to the first tape parameter, in RX(y, 0), not to the first QNode parameter in circuit(x, y).
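The numbers above can be reproduced with a small pure-Python sketch of the two-term parameter-shift rule applied at the tape level. It assumes the closed-form expectation value <Z> = cos(a) for RX(a) or RY(a) acting on |0>, and the helper names expvals and param_shift_tape are hypothetical, not PennyLane functions. argnum indexes tape parameters, and the result is then mapped back onto the QNode arguments x and y.

```python
import math


def expvals(tape_params):
    # Tape parameter 0 feeds RX on wire 0, tape parameter 1 feeds RY on wire 1.
    # Both start from |0>, and <Z> after RX(a) or RY(a) on |0> is cos(a).
    a, b = tape_params
    return [math.cos(a), math.cos(b)]


def param_shift_tape(f, tape_params, argnum, shift=math.pi / 2):
    """Two-term parameter-shift rule over *tape* parameters (hypothetical helper).

    Returns one gradient vector (one entry per output of f) per tape
    parameter. Tape parameters whose index is not in argnum are skipped
    and their gradient is left at zero.
    """
    n_out = len(f(tape_params))
    grads = []
    for i in range(len(tape_params)):
        if i not in argnum:
            grads.append([0.0] * n_out)
            continue
        plus, minus = list(tape_params), list(tape_params)
        plus[i] += shift
        minus[i] -= shift
        denom = 2 * math.sin(shift)
        grads.append([(fp - fm) / denom for fp, fm in zip(f(plus), f(minus))])
    return grads


x, y = 0.4, 0.1
tape_params = [y, x]           # queue order: RX(y, 0) comes before RY(x, 1)
tape_index = {"x": 1, "y": 0}  # where each QNode argument ended up on the tape

grads = param_shift_tape(expvals, tape_params, argnum=[0])
# Map tape-parameter gradients back onto the QNode arguments (x, y).
qnode_grads = (grads[tape_index["x"]], grads[tape_index["y"]])
print(qnode_grads)  # y-gradient starts with about -0.0998 (= -sin(0.1)); all else 0.0
```

The x-gradient is zero only because argnum=0 skipped tape parameter 1; the tiny -5.55e-17 in the PennyLane output is floating-point noise from the simulation, which the closed-form model above does not produce.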
In this issue, the metric_tensor function that needs modification is a batch_transform, which behaves the same way. Therefore, it is sufficient and consistent to consider only tape arguments when implementing argnum.
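For metric_tensor, the same tape-level convention could look like the following sketch of a block-diagonal, covariance-matrix computation. It is pure Python, handles only the single parametrized layer of the example circuit (RX(y, 0) and RY(x, 1) acting on |00>), and assumes that entries belonging to tape parameters outside argnum are left at zero so the output shape stays fixed; whether to zero-pad or return a smaller block is a design decision for the actual implementation. All helper names are hypothetical.

```python
def kron(a, b):
    # Kronecker product of two square matrices given as nested lists.
    n, m = len(a), len(b)
    return [[a[i][j] * b[k][l] for j in range(n) for l in range(m)]
            for i in range(n) for k in range(m)]


def matmul(a, b):
    return [[sum(a[r][k] * b[k][c] for k in range(len(b))) for c in range(len(b[0]))]
            for r in range(len(a))]


def expval(op, state):
    # Real part of <state|op|state>; for a product G_i G_j of Hermitian
    # generators this is the symmetrized expectation (1/2)<{G_i, G_j}>.
    op_state = [sum(op[r][c] * state[c] for c in range(len(state))) for r in range(len(op))]
    return sum(s.conjugate() * t for s, t in zip(state, op_state)).real


def metric_tensor_layer(generators, state, argnum=None):
    """Covariance-matrix metric tensor for one layer of commuting gates.

    argnum holds tape-parameter indices; entries involving any other tape
    parameter are not computed and stay 0.0 (assumed zero-padding).
    """
    n = len(generators)
    selected = range(n) if argnum is None else argnum
    g = [[0.0] * n for _ in range(n)]
    for i in selected:
        for j in selected:
            gi, gj = generators[i], generators[j]
            g[i][j] = expval(matmul(gi, gj), state) - expval(gi, state) * expval(gj, state)
    return g


def half(m):
    return [[0.5 * e for e in row] for row in m]


I2 = [[1, 0], [0, 1]]
X = [[0, 1], [1, 0]]
Y = [[0, -1j], [1j, 0]]

# Generators in tape order: tape parameter 0 is y (RX on wire 0),
# tape parameter 1 is x (RY on wire 1). The sign convention of the
# generators does not affect the covariance.
generators = [half(kron(X, I2)), half(kron(I2, Y))]
state = [1 + 0j, 0j, 0j, 0j]  # |00>, the state before the parametrized layer

g_full = metric_tensor_layer(generators, state)               # diagonal 0.25, 0.25
g_tape0 = metric_tensor_layer(generators, state, argnum=[0])  # only entry [0][0] is 0.25
```

Here argnum=[0] selects y, the first tape parameter, which matches the behaviour of the gradient transforms described above.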
Hi there, I just sent a suggestion in this PR for an implementation of the argnum argument. Let me know what you think.
Feature details
Adding argnum to transforms allows us to compute quantities in the forward pass while inside an autodiff framework. Similar to qml.grad or qml.transforms.classical_jacobian, it should be added to qml.metric_tensor as well.
Implementation
No response
How important would you say this feature is?
2: Somewhat important. Needed this quarter.
Additional information
No response