dwierichs opened 2 years ago
@dwierichs thanks for finding this! I'm just getting caught up after vacation, and I'm a little out of the loop, so just two quick questions:
argnum argument?

Hi @josh146

argnum ...)

Which fix do you mean here? I think we don't have one for JAX (as the forward passes don't work without argnum...).
Oh, I had a recollection that classical_jacobian was patched for JAX to only support non-trainable parameters? I could be wrong though!
Ah, yes, kind of :D By now this is patched again, and both trainable_only=True and trainable_only=False are supported.
In any case, I feel like adding argnum is the better way forward here, because it avoids patchwork treatment of the metric tensor behaviour for different interfaces?
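To make the role of argnum concrete: assuming the hybrid metric tensor of a QNode argument is obtained by contracting the tape metric tensor g (over the P gate parameters) with the classical Jacobian J of those gate parameters with respect to that argument, G = J^T g J, then computing the classical Jacobian for every QNode argument yields one J per argument, and the wrapper has to be told which one to use. The following is a minimal NumPy sketch under that assumption. The helper name select_and_contract and all numbers are hypothetical, and this is not PennyLane's actual implementation.

```python
import numpy as np


def select_and_contract(tape_mt, classical_jacs, argnum=None):
    """Hypothetical helper (not PennyLane's API): map a tape-level metric
    tensor to the metric tensor of one QNode argument via G = J^T g J.

    tape_mt: (P, P) metric tensor w.r.t. the P gate parameters on the tape.
    classical_jacs: tuple with one (P, n_i) classical Jacobian per QNode argument.
    argnum: index of the QNode argument to use; may be omitted only if
        exactly one classical Jacobian was computed.
    """
    if argnum is None:
        if len(classical_jacs) != 1:
            # The ambiguity from the issue: several Jacobians, no way to choose.
            raise ValueError(
                f"Got {len(classical_jacs)} classical Jacobians; pass argnum to pick one."
            )
        argnum = 0
    jac = np.asarray(classical_jacs[argnum])
    # Pull back the tape metric tensor to the chosen QNode argument.
    return jac.T @ np.asarray(tape_mt) @ jac


# Made-up numbers for illustration: three gates whose angles are the three
# entries of y, plus a scalar QNode argument x that feeds none of these gates.
tape_mt = np.diag([0.25, 0.25, 0.1])
jac_x = np.zeros((3, 1))  # d(gate params)/dx
jac_y = np.eye(3)         # d(gate params)/dy
jacs = (jac_x, jac_y)

mt_y = select_and_contract(tape_mt, jacs, argnum=1)
print(mt_y)  # equals tape_mt, since the gate angles are y itself

try:
    select_and_contract(tape_mt, jacs)  # no argnum: ambiguous
except ValueError as err:
    print(err)
```

Passing argnum=1 selects the Jacobian for y, while omitting argnum with two Jacobians raises, mirroring the situation in which the wrapper cannot decide which classical Jacobian to contract with.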
On the side, I noticed that the same bug exists for JAX.
oof 😬
I am in agreement, let's prioritize argnum.
This bug is fixed via argnum for JAX, but not for Torch.
Expected behavior
In the following QNode, we have two arguments:
The expected behaviour for metric_tensor would now be a single array output with the metric tensor w.r.t. y, namely:
Actual behavior
The computation raises an error (see below). As far as I can tell, this is because in Torch the classical Jacobian is computed for all QNode arguments. That is, without argnum as a keyword argument, the QNode wrapper does not know which of the classical Jacobians to contract with the tape metric tensor(s). Another reason to implement #1880 soonish.
Indeed, when setting hybrid=False in metric_tensor, we get the correct metric tensor for the three gates that are controlled by entries of the QNode argument y:
Additional information
No response
Source code
No response
Tracebacks
System information