PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] Taking gradients with `interface='jax'` or `interface='jax-python'` and the `qiskit.aer` backend fails #2242

Closed quantshah closed 2 years ago

quantshah commented 2 years ago

Expected behavior

The code should compute the gradients with jax.grad using the qiskit.aer backend.

Interestingly, as @josh146 pointed out, the gradient computation works when the circuit is decorated with @jit or when the interface is set to jax-jit instead of jax. The bug therefore appears to be specific to the jax-python interface. To check, switch the interface in the source code below between jax-jit and jax-python.

Actual behavior

Error:

raise CircuitError(f"Invalid param type {type(parameter)} for gate {self.name}.")
qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."
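
To illustrate what the error message suggests is happening, here is a minimal pure-Python sketch. It is not Qiskit's or PennyLane's actual code: `FakeDeviceArray`, the simplified `validate_parameter`, and the `unwrap` helper are all hypothetical stand-ins, and the real Qiskit check accepts more types than plain Python numbers. The idea is only that a gate constructor which type-checks its parameters will reject an array wrapper, while the same value converted to a Python float passes.

```python
import numbers


class CircuitError(Exception):
    """Stand-in for qiskit.circuit.exceptions.CircuitError (assumption)."""


class FakeDeviceArray:
    """Hypothetical stand-in for a zero-dimensional JAX array."""

    def __init__(self, value):
        self._value = value

    def __float__(self):
        # Like a real scalar array, it can be converted to a Python float.
        return float(self._value)


def validate_parameter(parameter, gate_name="rz"):
    """Simplified type check: accept only plain Python real numbers."""
    if isinstance(parameter, numbers.Real):
        return parameter
    raise CircuitError(f"Invalid param type {type(parameter)} for gate {gate_name}.")


def unwrap(parameter):
    """Hypothetical helper: convert array-like scalars to Python floats."""
    return parameter if isinstance(parameter, numbers.Real) else float(parameter)


angle = FakeDeviceArray(0.3)

# Passing the wrapper directly is rejected by the type check.
try:
    validate_parameter(angle)
    rejected = False
except CircuitError:
    rejected = True

# Converting to a plain float first lets the same value through.
accepted = validate_parameter(unwrap(angle))
print(rejected, accepted)  # True 0.3
```

Whether the eventual fix converts parameters at the device boundary in this way is not stated in this thread; the sketch only shows why a plain-float conversion would sidestep this kind of type check.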

Additional information

No response

Source code

import pennylane as qml
from pennylane import numpy as np

import jax.numpy as jnp
from jax import grad
from jax import jit

N = 2
dev = qml.device("qiskit.aer", wires=N, shots=1000)

# Works with @jit or interface="jax-jit", but fails with interface="jax" or interface="jax-python"
@qml.qnode(dev, interface="jax")
def energy(params):
    qml.StronglyEntanglingLayers(params, wires=range(N))
    return qml.expval(qml.PauliZ(0))

params = np.random.random(qml.StronglyEntanglingLayers.shape(n_layers=2, n_wires=N))
params = jnp.array(params)
print("Energy:", energy(params))
print(grad(energy)(params))

qml.about()  # prints the system information itself and returns None

Tracebacks

(mitiq) Shahnawazs-MacBook-Pro:tests shahnawaz$ python test_jax_qiskit.py 
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Energy: 0.242
Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/phd/implicit-diff/quantum-implicit-differentiation/tests/test_jax_qiskit.py", line 21, in <module>
    print(grad(energy)(params))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/qnode.py", line 560, in __call__
    res = qml.execute(
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 412, in execute
    res = _execute(
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 110, in execute
    return _execute(
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 258, in _execute
    return wrapped_exec(params)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/phd/implicit-diff/quantum-implicit-differentiation/tests/test_jax_qiskit.py", line 21, in <module>
    print(grad(energy)(params))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/api.py", line 991, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/api.py", line 1073, in value_and_grad_f
    g = vjp_py(jax.lax._one(ans))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/tree_util.py", line 279, in __call__
    return self.fun(*args, **kw)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/api.py", line 2428, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/tree_util.py", line 279, in __call__
    return self.fun(*args, **kw)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/interpreters/ad.py", line 123, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/interpreters/ad.py", line 229, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/interpreters/ad.py", line 694, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/custom_derivatives.py", line 640, in <lambda>
    bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 204, in wrapped_exec_bwd
    partial_res = execute_fn(vjp_tapes)[0]
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 173, in wrapper
    res = fn(execution_tapes.values(), **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 125, in fn
    return original_fn(tapes, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 426, in batch_execute
    compiled_circuits = self.compile_circuits(circuits)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 415, in compile_circuits
    self.create_circuit_object(circuit.operations, rotations=circuit.diagonalizing_gates)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 228, in create_circuit_object
    applied_operations = self.apply_operations(operations)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 289, in apply_operations
    gate = mapped_operation(*par)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/library/standard_gates/rz.py", line 61, in __init__
    super().__init__("rz", 1, [phi], label=label)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 40, in __init__
    super().__init__(name, num_qubits, 0, params, label=label)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 100, in __init__
    self.params = params  # must be at last (other properties may be required for validation)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 216, in params
    self._params.append(self.validate_parameter(single_param))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 241, in validate_parameter
    raise CircuitError(f"Invalid param type {type(parameter)} for gate {self.name}.")
jax._src.traceback_util.UnfilteredStackTrace: qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/phd/implicit-diff/quantum-implicit-differentiation/tests/test_jax_qiskit.py", line 21, in <module>
    print(grad(energy)(params))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 204, in wrapped_exec_bwd
    partial_res = execute_fn(vjp_tapes)[0]
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 173, in wrapper
    res = fn(execution_tapes.values(), **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 125, in fn
    return original_fn(tapes, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 426, in batch_execute
    compiled_circuits = self.compile_circuits(circuits)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 415, in compile_circuits
    self.create_circuit_object(circuit.operations, rotations=circuit.diagonalizing_gates)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 228, in create_circuit_object
    applied_operations = self.apply_operations(operations)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 289, in apply_operations
    gate = mapped_operation(*par)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/library/standard_gates/rz.py", line 61, in __init__
    super().__init__("rz", 1, [phi], label=label)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 40, in __init__
    super().__init__(name, num_qubits, 0, params, label=label)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 100, in __init__
    self.params = params  # must be at last (other properties may be required for validation)
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 216, in params
    self._params.append(self.validate_parameter(single_param))
  File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 241, in validate_parameter
    raise CircuitError(f"Invalid param type {type(parameter)} for gate {self.name}.")
qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."

System information

WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.
Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
Name: PennyLane
Version: 0.21.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning, PennyLane-qiskit
Platform info:           macOS-11.2.3-x86_64-i386-64bit
Python version:          3.9.10
Numpy version:           1.22.2
Scipy version:           1.7.3
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.21.0)
- qiskit.aer (PennyLane-qiskit-0.21.0)
- qiskit.basicaer (PennyLane-qiskit-0.21.0)
- qiskit.ibmq (PennyLane-qiskit-0.21.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.21.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.21.0)
- default.gaussian (PennyLane-0.21.0)
- default.mixed (PennyLane-0.21.0)
- default.qubit (PennyLane-0.21.0)
- default.qubit.autograd (PennyLane-0.21.0)
- default.qubit.jax (PennyLane-0.21.0)
- default.qubit.tf (PennyLane-0.21.0)
- default.qubit.torch (PennyLane-0.21.0)

CatalinaAlbornoz commented 2 years ago

Hi @quantshah, thank you for opening this issue. We've created a PR to fix this bug.