PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Expected behavior
The code should compute gradients with jax.grad on the qiskit.aer backend.
Interestingly, as @josh146 outlined, the code works when the circuit is decorated with @jit or when the interface is set to jax-jit instead of plain jax. The bug therefore seems to lie in the jax-python interface. To check, switch the interface between jax-jit and jax-python.
Actual behavior
Error:
raise CircuitError(f"Invalid param type {type(parameter)} for gate {self.name}.")
qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."
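The message suggests the gate constructor rejects the parameter because of its type, not its value. Below is a minimal standard-library sketch of that failure mode. `FakeDeviceArray`, `validate_parameter`, `to_python_scalar`, and the `ACCEPTED_TYPES` allow-list are hypothetical stand-ins for illustration only; they are not Qiskit or PennyLane code, and the real validation logic may differ.

```python
class FakeDeviceArray:
    """Hypothetical stand-in for a 0-d JAX array holding one number."""

    def __init__(self, value):
        self._value = value

    def __float__(self):
        return float(self._value)


# Assumption: a simplified allow-list of parameter types.
ACCEPTED_TYPES = (int, float)


def validate_parameter(parameter, gate_name="rz"):
    """Reject any parameter whose type is not on the allow-list."""
    if isinstance(parameter, ACCEPTED_TYPES):
        return parameter
    raise TypeError(f"Invalid param type {type(parameter)} for gate {gate_name}.")


def to_python_scalar(parameter):
    """Convert an array-like scalar to a plain Python float."""
    return float(parameter)


raw = FakeDeviceArray(0.5)

# Passing the wrapper object directly fails the type check.
try:
    validate_parameter(raw)
    rejected = False
except TypeError:
    rejected = True

# Converting to a plain float first passes the same check.
accepted = validate_parameter(to_python_scalar(raw))
print(rejected, accepted)
```

If the real cause matches this sketch, converting parameters to plain Python or NumPy scalars before they reach the device would avoid the error. That is an inference from the error message, not a confirmed diagnosis.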
Additional information
No response
Source code
import pennylane as qml
from pennylane import numpy as np
import jax.numpy as jnp
from jax import grad
from jax import jit
N = 2
dev = qml.device("qiskit.aer", wires=N, shots=1000)
# WORKS with @jit or interface="jax-jit" but not with interface="jax-python"
@qml.qnode(dev, interface="jax")
def energy(params):
    qml.StronglyEntanglingLayers(params, wires=range(N))
    return qml.expval(qml.PauliZ(0))
params = np.random.random(qml.StronglyEntanglingLayers.shape(n_layers=2, n_wires=N))
params = jnp.array(params)
print("Energy:", energy(params))
print(grad(energy)(params))
print(qml.about())
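For comparison, here is a sketch of the workaround described in this report: the same circuit with the interface set to "jax-jit". The helper names `build_energy_qnode` and `main` are made up for illustration. The behavior is taken from this report for PennyLane 0.21.0 with PennyLane-qiskit 0.21.0 and has not been checked on other versions. Imports sit inside the functions so the definitions load even where those packages are missing.

```python
def build_energy_qnode(n_wires=2, shots=1000, interface="jax-jit"):
    """Build the QNode from the report with a configurable interface.

    Hypothetical helper: the default "jax-jit" is the interface the
    report says works; pass "jax-python" to reproduce the failure.
    """
    import pennylane as qml  # requires PennyLane and PennyLane-qiskit

    dev = qml.device("qiskit.aer", wires=n_wires, shots=shots)

    @qml.qnode(dev, interface=interface)
    def energy(params):
        qml.StronglyEntanglingLayers(params, wires=range(n_wires))
        return qml.expval(qml.PauliZ(0))

    return energy


def main():
    """Evaluate the energy and its gradient using the workaround."""
    import jax
    import jax.numpy as jnp
    import numpy as np
    import pennylane as qml

    n_wires = 2
    shape = qml.StronglyEntanglingLayers.shape(n_layers=2, n_wires=n_wires)
    params = jnp.array(np.random.random(shape))

    energy = build_energy_qnode(n_wires=n_wires)
    print("Energy:", energy(params))
    print("Gradient:", jax.grad(energy)(params))
```

Calling `main()` needs PennyLane, PennyLane-qiskit, and JAX installed. Calling `build_energy_qnode(interface="jax-python")` instead should reproduce the error reported here.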
Tracebacks
(mitiq) Shahnawazs-MacBook-Pro:tests shahnawaz$ python test_jax_qiskit.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Energy: 0.242
Traceback (most recent call last):
File "/Users/shahnawaz/Dropbox/phd/implicit-diff/quantum-implicit-differentiation/tests/test_jax_qiskit.py", line 21, in <module>
print(grad(energy)(params))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/qnode.py", line 560, in __call__
res = qml.execute(
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 412, in execute
res = _execute(
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 110, in execute
return _execute(
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 258, in _execute
return wrapped_exec(params)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/shahnawaz/Dropbox/phd/implicit-diff/quantum-implicit-differentiation/tests/test_jax_qiskit.py", line 21, in <module>
print(grad(energy)(params))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/api.py", line 991, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/api.py", line 1073, in value_and_grad_f
g = vjp_py(jax.lax._one(ans))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/tree_util.py", line 279, in __call__
return self.fun(*args, **kw)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/api.py", line 2428, in _vjp_pullback_wrapper
ans = fun(*args)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/tree_util.py", line 279, in __call__
return self.fun(*args, **kw)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/interpreters/ad.py", line 123, in unbound_vjp
arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/interpreters/ad.py", line 229, in backward_pass
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/interpreters/ad.py", line 694, in _custom_lin_transpose
cts_in = bwd.call_wrapped(*res, *cts_out)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/_src/custom_derivatives.py", line 640, in <lambda>
bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 204, in wrapped_exec_bwd
partial_res = execute_fn(vjp_tapes)[0]
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 173, in wrapper
res = fn(execution_tapes.values(), **kwargs)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 125, in fn
return original_fn(tapes, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 426, in batch_execute
compiled_circuits = self.compile_circuits(circuits)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 415, in compile_circuits
self.create_circuit_object(circuit.operations, rotations=circuit.diagonalizing_gates)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 228, in create_circuit_object
applied_operations = self.apply_operations(operations)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 289, in apply_operations
gate = mapped_operation(*par)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/library/standard_gates/rz.py", line 61, in __init__
super().__init__("rz", 1, [phi], label=label)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 40, in __init__
super().__init__(name, num_qubits, 0, params, label=label)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 100, in __init__
self.params = params # must be at last (other properties may be required for validation)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 216, in params
self._params.append(self.validate_parameter(single_param))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 241, in validate_parameter
raise CircuitError(f"Invalid param type {type(parameter)} for gate {self.name}.")
jax._src.traceback_util.UnfilteredStackTrace: qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/shahnawaz/Dropbox/phd/implicit-diff/quantum-implicit-differentiation/tests/test_jax_qiskit.py", line 21, in <module>
print(grad(energy)(params))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/jax.py", line 204, in wrapped_exec_bwd
partial_res = execute_fn(vjp_tapes)[0]
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 173, in wrapper
res = fn(execution_tapes.values(), **kwargs)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane/interfaces/batch/__init__.py", line 125, in fn
return original_fn(tapes, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 426, in batch_execute
compiled_circuits = self.compile_circuits(circuits)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 415, in compile_circuits
self.create_circuit_object(circuit.operations, rotations=circuit.diagonalizing_gates)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 228, in create_circuit_object
applied_operations = self.apply_operations(operations)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/pennylane_qiskit/qiskit_device.py", line 289, in apply_operations
gate = mapped_operation(*par)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/library/standard_gates/rz.py", line 61, in __init__
super().__init__("rz", 1, [phi], label=label)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 40, in __init__
super().__init__(name, num_qubits, 0, params, label=label)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 100, in __init__
self.params = params # must be at last (other properties may be required for validation)
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/instruction.py", line 216, in params
self._params.append(self.validate_parameter(single_param))
File "/Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages/qiskit/circuit/gate.py", line 241, in validate_parameter
raise CircuitError(f"Invalid param type {type(parameter)} for gate {self.name}.")
qiskit.circuit.exceptions.CircuitError: "Invalid param type <class 'jaxlib.xla_extension.DeviceArray'> for gate rz."
System information
WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.
Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.
Name: PennyLane
Version: 0.21.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/shahnawaz/miniconda3/envs/mitiq/lib/python3.9/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning, PennyLane-qiskit
Platform info: macOS-11.2.3-x86_64-i386-64bit
Python version: 3.9.10
Numpy version: 1.22.2
Scipy version: 1.7.3
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.21.0)
- qiskit.aer (PennyLane-qiskit-0.21.0)
- qiskit.basicaer (PennyLane-qiskit-0.21.0)
- qiskit.ibmq (PennyLane-qiskit-0.21.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.21.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.21.0)
- default.gaussian (PennyLane-0.21.0)
- default.mixed (PennyLane-0.21.0)
- default.qubit (PennyLane-0.21.0)
- default.qubit.autograd (PennyLane-0.21.0)
- default.qubit.jax (PennyLane-0.21.0)
- default.qubit.tf (PennyLane-0.21.0)
- default.qubit.torch (PennyLane-0.21.0)
Existing GitHub issues
[X] I have searched existing GitHub issues to make sure the issue does not already exist.