PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] `BasisState` does not work with `jax.jit` #6006

Closed isaacdevlugt closed 2 months ago

isaacdevlugt commented 3 months ago

Expected behavior

I expect `BasisState` to work under `jax.jit` just as `BasisStatePreparation` does, since `BasisState` decomposes to `BasisStatePreparation`.
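The expectation depends on the preparation being traceable, meaning no Python branching on the bit values. The sketch below is a plain-JAX illustration and not PennyLane's actual decomposition code. It prepares a single bit `b` with a gate sequence parametrized by `b`, `PhaseShift(b·π/2) · RX(b·π) · PhaseShift(b·π/2)`, which maps |0⟩ to |b⟩ for b ∈ {0, 1}. The helper names `rx`, `phase_shift` and `prepare_bit` are hypothetical.

```python
import jax
import jax.numpy as jnp


def rx(theta):
    # Standard single-qubit RX(theta) rotation matrix
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def phase_shift(phi):
    # PhaseShift(phi) = diag(1, e^{i*phi})
    return jnp.array([[1.0, 0.0], [0.0, jnp.exp(1j * phi)]])


@jax.jit
def prepare_bit(b):
    # b is a traced 0/1 value; it only enters as a rotation angle,
    # so tracing never needs its concrete value
    theta = jnp.pi * b
    u = phase_shift(theta / 2) @ rx(theta) @ phase_shift(theta / 2)
    ket0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    return u @ ket0


print(prepare_bit(1))  # approximately [0, 1]
print(prepare_bit(0))  # approximately [1, 0]
```

For b = 1 the product reduces to [[cos(π/2), 1], [1, -cos(π/2)]]. This sends |0⟩ to |1⟩ up to floating-point error, so the comparison in the usage above is approximate.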

Actual behavior

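`BasisState` is not jit-friendly on `default.qubit`. Calling the jitted QNode fails during tracing with a `TracerBoolConversionError` raised from `BasisState.state_vector`.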
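The failure is not specific to PennyLane. Any Python-level truth test on a value derived from a jitted argument raises this error. The minimal pure-JAX sketch below reproduces it with the same membership check that appears in the traceback, and it also shows a traceable check built from `jnp` operations. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def validate_python(bits):
    # Iterating is allowed because the shape is static, but `i not in [0, 1]`
    # needs bool(i == 0), which a tracer cannot provide during jit tracing
    if any(i not in [0, 1] for i in bits):
        raise ValueError("bits must be 0 or 1")
    return bits


@jax.jit
def is_valid_traced(bits):
    # Traceable alternative: the check stays an array computed at run time
    return jnp.all((bits == 0) | (bits == 1))


def raises_tracer_bool_error(fn, x):
    # Returns True if calling fn(x) fails while tracing with the error from the report
    try:
        fn(x)
    except jax.errors.TracerBoolConversionError:
        return True
    return False


bits = jnp.array([0, 1, 1])
print(raises_tracer_bool_error(validate_python, bits))  # True
print(bool(is_valid_traced(bits)))  # True
```

The traced version cannot raise a Python `ValueError` inside `jit`. It can only return a flag, so any error reporting has to happen outside the compiled function.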

Additional information

No response

Source code

```python
import pennylane as qml
import jax
from jax import numpy as jnp

wires = [0, 1, 2]  # wires shared by the device and the state preparation
dev = qml.device("default.qubit", wires=wires)

@jax.jit
@qml.qnode(dev)
def circuit_BasisState(n):
    qml.BasisState(n, wires=wires)  # fails under jax.jit (TracerBoolConversionError)
    # qml.BasisStatePreparation(n, wires=wires)  # works
    return qml.state()

n = jnp.array([0, 1, 1])
print(qml.draw(circuit_BasisState)(n))
```

Tracebacks

---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[64], line 19
     16 n = jnp.array([0, 1, 1])
     18 #print(qml.draw(circuit_BasisEmbedding)(n))
---> 19 print(qml.draw(circuit_BasisState)(n))

File ~/Documents/pennylane/pennylane/drawer/draw.py:304, in draw.<locals>.wrapper(*args, **kwargs)
    302 @wraps(qnode)
    303 def wrapper(*args, **kwargs):
--> 304     tape = qml.tape.make_qscript(qnode)(*args, **kwargs)
    306     if wire_order:
    307         _wire_order = wire_order

File ~/Documents/pennylane/pennylane/tape/qscript.py:1298, in make_qscript.<locals>.wrapper(*args, **kwargs)
   1296 def wrapper(*args, **kwargs):
   1297     with AnnotatedQueue() as q:
-> 1298         fn(*args, **kwargs)
   1300     return QuantumScript.from_queue(q, shots)

    [... skipping hidden 12 frame]

File ~/Documents/pennylane/pennylane/workflow/qnode.py:1164, in QNode.__call__(self, *args, **kwargs)
   1162 if qml.capture.enabled():
   1163     return qml.capture.qnode_call(self, *args, **kwargs)
-> 1164 return self._impl_call(*args, **kwargs)

File ~/Documents/pennylane/pennylane/workflow/qnode.py:1150, in QNode._impl_call(self, *args, **kwargs)
   1147 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1149 try:
-> 1150     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1151 finally:
   1152     if old_interface == "auto":

File ~/Documents/pennylane/pennylane/workflow/qnode.py:1103, in QNode._execution_component(self, args, kwargs, override_shots)
   1100 _prune_dynamic_transform(full_transform_program, inner_transform_program)
   1102 # pylint: disable=unexpected-keyword-arg
-> 1103 res = qml.execute(
   1104     (self._tape,),
   1105     device=self.device,
   1106     gradient_fn=self.gradient_fn,
   1107     interface=self.interface,
   1108     transform_program=full_transform_program,
   1109     inner_transform=inner_transform_program,
   1110     config=config,
   1111     gradient_kwargs=self.gradient_kwargs,
   1112     override_shots=override_shots,
   1113     **self.execute_kwargs,
   1114 )
   1115 res = res[0]
   1117 # convert result to the interface in case the qfunc has no parameters

File ~/Documents/pennylane/pennylane/workflow/execution.py:666, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
    664 # Exiting early if we do not need to deal with an interface boundary
    665 if no_interface_boundary_required:
--> 666     results = inner_execute(tapes)
    667     return post_processing(results)
    669 _grad_on_execution = False

File ~/Documents/pennylane/pennylane/workflow/execution.py:316, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
    313     transformed_tapes = tuple(expand_fn(t) for t in transformed_tapes)
    315 if transformed_tapes:
--> 316     results = device_execution(transformed_tapes)
    317 else:
    318     results = ()

File ~/Documents/pennylane/pennylane/devices/modifiers/simulator_tracking.py:30, in _track_execute.<locals>.execute(self, circuits, execution_config)
     28 @wraps(untracked_execute)
     29 def execute(self, circuits, execution_config=DefaultExecutionConfig):
---> 30     results = untracked_execute(self, circuits, execution_config)
     31     if isinstance(circuits, QuantumScript):
     32         batch = (circuits,)

File ~/Documents/pennylane/pennylane/devices/modifiers/single_tape_support.py:32, in _make_execute.<locals>.execute(self, circuits, execution_config)
     30     is_single_circuit = True
     31     circuits = (circuits,)
---> 32 results = batch_execute(self, circuits, execution_config)
     33 return results[0] if is_single_circuit else results

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/pennylane/pennylane/devices/default_qubit.py:597, in DefaultQubit.execute(self, circuits, execution_config)
    594 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    596 if max_workers is None:
--> 597     return tuple(
    598         _simulate_wrapper(
    599             c,
    600             {
    601                 "rng": self._rng,
    602                 "debugger": self._debugger,
    603                 "interface": interface,
    604                 "state_cache": self._state_cache,
    605                 "prng_key": _key,
    606                 "mcm_method": execution_config.mcm_config.mcm_method,
    607                 "postselect_mode": execution_config.mcm_config.postselect_mode,
    608             },
    609         )
    610         for c, _key in zip(circuits, prng_keys)
    611     )
    613 vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
    614 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/Documents/pennylane/pennylane/devices/default_qubit.py:598, in <genexpr>(.0)
    594 prng_keys = [self.get_prng_keys()[0] for _ in range(len(circuits))]
    596 if max_workers is None:
    597     return tuple(
--> 598         _simulate_wrapper(
    599             c,
    600             {
    601                 "rng": self._rng,
    602                 "debugger": self._debugger,
    603                 "interface": interface,
    604                 "state_cache": self._state_cache,
    605                 "prng_key": _key,
    606                 "mcm_method": execution_config.mcm_config.mcm_method,
    607                 "postselect_mode": execution_config.mcm_config.postselect_mode,
    608             },
    609         )
    610         for c, _key in zip(circuits, prng_keys)
    611     )
    613 vanilla_circuits = convert_to_numpy_parameters(circuits)[0]
    614 seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

File ~/Documents/pennylane/pennylane/devices/default_qubit.py:863, in _simulate_wrapper(circuit, kwargs)
    862 def _simulate_wrapper(circuit, kwargs):
--> 863     return simulate(circuit, **kwargs)

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/pennylane/pennylane/devices/qubit/simulate.py:354, in simulate(circuit, debugger, state_cache, **execution_kwargs)
    351     return tuple(results)
    353 ops_key, meas_key = jax_random_split(prng_key)
--> 354 state, is_state_batched = get_final_state(
    355     circuit, debugger=debugger, prng_key=ops_key, **execution_kwargs
    356 )
    357 if state_cache is not None:
    358     state_cache[circuit.hash] = state

File ~/Documents/pennylane/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Documents/pennylane/pennylane/devices/qubit/simulate.py:165, in get_final_state(circuit, debugger, **execution_kwargs)
    162 if len(circuit) > 0 and isinstance(circuit[0], qml.operation.StatePrepBase):
    163     prep = circuit[0]
--> 165 state = create_initial_state(sorted(circuit.op_wires), prep, like=INTERFACE_TO_LIKE[interface])
    167 # initial state is batched only if the state preparation (if it exists) is batched
    168 is_state_batched = bool(prep and prep.batch_size is not None)

File ~/Documents/pennylane/pennylane/devices/qubit/initialize_state.py:46, in create_initial_state(wires, prep_operation, like)
     43     state[(0,) * num_wires] = 1
     44     return qml.math.asarray(state, like=like)
---> 46 return qml.math.asarray(prep_operation.state_vector(wire_order=list(wires)), like=like)

File ~/Documents/pennylane/pennylane/ops/qubit/state_preparation.py:104, in BasisState.state_vector(self, wire_order)
    102 """Returns a statevector of shape ``(2,) * num_wires``."""
    103 prep_vals = self.parameters[0]
--> 104 if any(i not in [0, 1] for i in prep_vals):
    105     raise ValueError("BasisState parameter must consist of 0 or 1 integers.")
    107 if (num_wires := len(self.wires)) != len(prep_vals):

File ~/Documents/pennylane/pennylane/ops/qubit/state_preparation.py:104, in <genexpr>(.0)
    102 """Returns a statevector of shape ``(2,) * num_wires``."""
    103 prep_vals = self.parameters[0]
--> 104 if any(i not in [0, 1] for i in prep_vals):
    105     raise ValueError("BasisState parameter must consist of 0 or 1 integers.")
    107 if (num_wires := len(self.wires)) != len(prep_vals):

    [... skipping hidden 1 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/core.py:1510, in concretization_function_error.<locals>.error(self, arg)
   1509 def error(self, arg):
-> 1510   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function circuit_BasisState at /var/folders/cn/h46l05vn2qd9c7ldxf0g905c0000gq/T/ipykernel_53736/2286483519.py:9 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
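The last PennyLane frame points at `BasisState.state_vector`, which checks its parameter with a Python generator (`any(i not in [0, 1] for i in prep_vals)`) before building the state. The sketch below shows one traceable way to build the same state vector. It is not the change made in the linked PR, whose contents are not shown here. It computes the basis index arithmetically from the bits and writes a 1 into a zero vector with `.at[...].set(...)`, so nothing forces a concrete value. `basis_state_vector` is a hypothetical name. The sketch assumes the first entry of `bits` is the most significant qubit, which matches PennyLane's wire-ordering convention.

```python
import jax
import jax.numpy as jnp


@jax.jit
def basis_state_vector(bits):
    # bits: 1-D array of 0/1 values, first entry = most significant qubit
    num_wires = bits.shape[0]  # shape is static under jit
    # Index of |b_0 b_1 ... b_{n-1}> is sum_k b_k * 2^(n-1-k)
    powers = 2 ** jnp.arange(num_wires - 1, -1, -1)
    index = jnp.dot(bits.astype(jnp.int32), powers)
    state = jnp.zeros(2**num_wires, dtype=jnp.complex64).at[index].set(1.0)
    # Return shape (2,) * num_wires, as the state_vector docstring describes
    return state.reshape((2,) * num_wires)


out = basis_state_vector(jnp.array([0, 1, 1]))
print(out.shape)  # (2, 2, 2)
print(out[0, 1, 1])  # (1+0j)
```

This sketch does not validate that every entry is 0 or 1. Under `jit`, such a check can only be computed as an array flag (for example `jnp.all((bits == 0) | (bits == 1))`) or performed when the input is concrete, outside the traced function.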

System information

Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.5-arm64-arm-64bit
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- default.clifford (PennyLane-0.38.0.dev0)
- default.gaussian (PennyLane-0.38.0.dev0)
- default.mixed (PennyLane-0.38.0.dev0)
- default.qubit (PennyLane-0.38.0.dev0)
- default.qubit.autograd (PennyLane-0.38.0.dev0)
- default.qubit.jax (PennyLane-0.38.0.dev0)
- default.qubit.legacy (PennyLane-0.38.0.dev0)
- default.qubit.tf (PennyLane-0.38.0.dev0)
- default.qubit.torch (PennyLane-0.38.0.dev0)
- default.qutrit (PennyLane-0.38.0.dev0)
- default.qutrit.mixed (PennyLane-0.38.0.dev0)
- default.tensor (PennyLane-0.38.0.dev0)
- null.qubit (PennyLane-0.38.0.dev0)
- lightning.qubit (PennyLane_Lightning-0.37.0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0)
- oqc.cloud (PennyLane-Catalyst-0.7.0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0)


DSGuala commented 2 months ago

Closing this issue as it was resolved in https://github.com/PennyLaneAI/pennylane/pull/6021