PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
I expect `BasisState`, when used with `jax.jit`, to work just like `BasisStatePreparation` does, because `BasisState` decomposes to `BasisStatePreparation`.
Actual behavior
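`BasisState` is not `jax.jit`-compatible on `default.qubit`: tracing the circuit fails with a `TracerBoolConversionError` (see the traceback below).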
Additional information
No response
Source code
import pennylane as qml
import jax
from jax import numpy as jnp

# Minimal reproduction: per this report, qml.BasisState fails under jax.jit on
# default.qubit, while qml.BasisStatePreparation works in the same circuit.
dev = qml.device("default.qubit", wires=3)
# The original snippet referenced `wires` without defining it; use the
# device's three wires so the repro reaches the jit failure, not a NameError.
wires = [0, 1, 2]


@jax.jit
@qml.qnode(dev)
def circuit_BasisState(n):
    """Prepare the basis state given by the bit array ``n`` and return the state."""
    qml.BasisState(n, wires)  # doesn't work: raises TracerBoolConversionError under jit
    # qml.BasisStatePreparation(n, wires)  # works
    return qml.state()


n = jnp.array([0, 1, 1])
print(qml.draw(circuit_BasisState)(n))
Tracebacks
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[64], line 19
16 n = jnp.array([0, 1, 1])
18 #print(qml.draw(circuit_BasisEmbedding)(n))
---> 19 print(qml.draw(circuit_BasisState)(n))
File ~/Documents/pennylane/pennylane/drawer/draw.py:304, in draw.<locals>.wrapper(*args, **kwargs)
302 @wraps(qnode)
303 def wrapper(*args, **kwargs):
--> 304 tape = qml.tape.make_qscript(qnode)(*args, **kwargs)
306 if wire_order:
307 _wire_order = wire_order
File ~/Documents/pennylane/pennylane/tape/qscript.py:1298, in make_qscript.<locals>.wrapper(*args, **kwargs)
1296 def wrapper(*args, **kwargs):
1297 with AnnotatedQueue() as q:
-> 1298 fn(*args, **kwargs)
1300 return QuantumScript.from_queue(q, shots)
[... skipping hidden 12 frame]
File ~/Documents/pennylane/pennylane/workflow/qnode.py:1164, in QNode.__call__(self, *args, **kwargs)
1162 if qml.capture.enabled():
1163 return qml.capture.qnode_call(self, *args, **kwargs)
-> 1164 return self._impl_call(*args, **kwargs)
File ~/Documents/pennylane/pennylane/workflow/qnode.py:1150, in QNode._impl_call(self, *args, **kwargs)
1147 self._update_gradient_fn(shots=override_shots, tape=self._tape)
1149 try:
-> 1150 res = self._execution_component(args, kwargs, override_shots=override_shots)
1151 finally:
1152 if old_interface == "auto":
File ~/Documents/pennylane/pennylane/workflow/qnode.py:1103, in QNode._execution_component(self, args, kwargs, override_shots)
1100 _prune_dynamic_transform(full_transform_program, inner_transform_program)
1102 # pylint: disable=unexpected-keyword-arg
-> 1103 res = qml.execute(
1104 (self._tape,),
1105 device=self.device,
1106 gradient_fn=self.gradient_fn,
1107 interface=self.interface,
1108 transform_program=full_transform_program,
1109 inner_transform=inner_transform_program,
1110 config=config,
1111 gradient_kwargs=self.gradient_kwargs,
1112 override_shots=override_shots,
1113 **self.execute_kwargs,
1114 )
1115 res = res[0]
1117 # convert result to the interface in case the qfunc has no parameters
File ~/Documents/pennylane/pennylane/workflow/execution.py:666, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
664 # Exiting early if we do not need to deal with an interface boundary
665 if no_interface_boundary_required:
--> 666 results = inner_execute(tapes)
667 return post_processing(results)
669 _grad_on_execution = False
File ~/Documents/pennylane/pennylane/workflow/execution.py:316, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
313 transformed_tapes = tuple(expand_fn(t) for t in transformed_tapes)
315 if transformed_tapes:
--> 316 results = device_execution(transformed_tapes)
317 else:
318 results = ()
File ~/Documents/pennylane/pennylane/devices/modifiers/simulator_tracking.py:30, in _track_execute.<locals>.execute(self, circuits, execution_config)
28 @wraps(untracked_execute)
29 def execute(self, circuits, execution_config=DefaultExecutionConfig):
---> 30 results = untracked_execute(self, circuits, execution_config)
31 if isinstance(circuits, QuantumScript):
32 batch = (circuits,)
File ~/Documents/pennylane/pennylane/devices/modifiers/single_tape_support.py:32, in _make_execute.<locals>.execute(self, circuits, execution_config)
30 is_single_circuit = True
31 circuits = (circuits,)
---> 32 results = batch_execute(self, circuits, execution_config)
33 return results[0] if is_single_circuit else results
File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/Documents/pennylane/pennylane/devices/default_qubit.py:597, in DefaultQubit.execute(self, circuits, execution_config)
594 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
596 if max_workers is None:
--> 597 return tuple(
598 _simulate_wrapper(
599 c,
600 {
601 "rng": self._rng,
602 "debugger": self._debugger,
603 "interface": interface,
604 "state_cache": self._state_cache,
605 "prng_key": _key,
606 "mcm_method": execution_config.mcm_config.mcm_method,
607 "postselect_mode": execution_config.mcm_config.postselect_mode,
608 },
609 )
610 for c, _key in zip(circuits, prng_keys)
611 )
613 vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
614 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))
File ~/Documents/pennylane/pennylane/devices/default_qubit.py:598, in <genexpr>(.0)
594 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
596 if max_workers is None:
597 return tuple(
--> 598 _simulate_wrapper(
599 c,
600 {
601 "rng": self._rng,
602 "debugger": self._debugger,
603 "interface": interface,
604 "state_cache": self._state_cache,
605 "prng_key": _key,
606 "mcm_method": execution_config.mcm_config.mcm_method,
607 "postselect_mode": execution_config.mcm_config.postselect_mode,
608 },
609 )
610 for c, _key in zip(circuits, prng_keys)
611 )
613 vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
614 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))
File ~/Documents/pennylane/pennylane/devices/default_qubit.py:863, in _simulate_wrapper(circuit, kwargs)
862 def _simulate_wrapper(circuit, kwargs):
--> 863 return simulate(circuit, **kwargs)
File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/Documents/pennylane/pennylane/devices/qubit/simulate.py:354, in simulate(circuit, debugger, state_cache, **execution_kwargs)
351 return tuple(results)
353 ops_key, meas_key = jax_random_split(prng_key)
--> 354 state, is_state_batched = get_final_state(
355 circuit, debugger=debugger, prng_key=ops_key, **execution_kwargs
356 )
357 if state_cache is not None:
358 state_cache[circuit.hash] = state
File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
54 s_caller = "::L".join(
55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
56 )
57 lgr.debug(
58 f"Calling {f_string} from {s_caller}",
59 **_debug_log_kwargs,
60 )
---> 61 return func(*args, **kwargs)
File ~/Documents/pennylane/pennylane/devices/qubit/simulate.py:165, in get_final_state(circuit, debugger, **execution_kwargs)
162 if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase):
163 prep = circuit[0]
--> 165 state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface])
167 # initial state is batched only if the state preparation (if it exists) is batched
168 is_state_batched = bool(prep and prep.batch_size is not None)
File ~/Documents/pennylane/pennylane/devices/qubit/initialize_state.py:46, in create_initial_state(wires, prep_operation, like)
43 state[(0,) * num_wires] = 1
44 return qml.math.asarray(state, like=like)
---> 46 return qml.math.asarray(prep_operation.state_vector(wire_order=list(wires)), like=like)
File ~/Documents/pennylane/pennylane/ops/qubit/state_preparation.py:104, in BasisState.state_vector(self, wire_order)
102 """Returns a statevector of shape ``(2,) * num_wires``."""
103 prep_vals = self.parameters[0]
--> 104 if any(i not in [0, 1] for i in prep_vals):
105 raise ValueError("BasisState parameter must consist of 0 or 1 integers.")
107 if (num_wires := len(self.wires)) != len(prep_vals):
File ~/Documents/pennylane/pennylane/ops/qubit/state_preparation.py:104, in <genexpr>(.0)
102 """Returns a statevector of shape ``(2,) * num_wires``."""
103 prep_vals = self.parameters[0]
--> 104 if any(i not in [0, 1] for i in prep_vals):
105 raise ValueError("BasisState parameter must consist of 0 or 1 integers.")
107 if (num_wires := len(self.wires)) != len(prep_vals):
[... skipping hidden 1 frame]
File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/core.py:1510, in concretization_function_error.<locals>.error(self, arg)
1509 def error(self, arg):
-> 1510 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function circuit_BasisState at /var/folders/cn/h46l05vn2qd9c7ldxf0g905c0000gq/T/ipykernel_53736/2286483519.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
Expected behavior
I expect that `BasisState`, when used with `jit`, should behave nicely like `BasisStatePreparation` does, because `BasisState` decomposes to `BasisStatePreparation`.
Actual behavior
`BasisState` isn't `jit`-friendly with `default.qubit`.
Additional information
No response
Source code
Tracebacks
System information
Existing GitHub issues