PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Apache License 2.0

[BUG] `BasisEmbedding` does not work with `lightning.qubit` and `jax.jit` #6008

Closed: isaacdevlugt closed this issue 9 hours ago

isaacdevlugt commented 1 month ago

Expected behavior

A circuit that contains BasisEmbedding should work under jax.jit on both default.qubit and lightning.qubit.

Actual behavior

With lightning.qubit, the jitted circuit fails with the ValueError shown in the traceback below.

Additional information

I suspect BasisState is affected as well.

Source code

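import pennylane as qml
import jax
from jax import numpy as jnp

dev = qml.device("lightning.qubit", wires=3)

@jax.jit  # JIT-compiling the QNode is what triggers the failure on lightning.qubit
@qml.qnode(dev)
def circuit_BasisEmbedding(n):
    # Prepare the computational basis state given by the bitstring n on wires 0, 1, 2
    qml.BasisEmbedding(n, wires=range(3))
    return qml.state()

n = jnp.array([0, 1, 1])

print(circuit_BasisEmbedding(n))
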
TracerArrayConversionError                Traceback (most recent call last)
File ~/Documents/pennylane/pennylane/math/, in _to_numpy_jax(x)
    782 try:
--> 783     return np.array(getattr(x, "val", x))
    784 except TracerArrayConversionError as e:

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/, in Tracer.__array__(self, *args, **kw)
    709 def __array__(self, *args, **kw):
--> 710   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[].

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[71], line 18
     14     return qml.state()
     16 n = jnp.array([0, 1, 1])
---> 18 print(circuit_BasisEmbedding(n))#print(qml.draw(circuit_BasisState)(n))

    [... skipping hidden 12 frame]

File ~/Documents/pennylane/pennylane/workflow/, in QNode.__call__(self, *args, **kwargs)
   1162 if qml.capture.enabled():
   1163     return qml.capture.qnode_call(self, *args, **kwargs)
-> 1164 return self._impl_call(*args, **kwargs)

File ~/Documents/pennylane/pennylane/workflow/, in QNode._impl_call(self, *args, **kwargs)
   1147 self._update_gradient_fn(shots=override_shots, tape=self._tape)
   1149 try:
-> 1150     res = self._execution_component(args, kwargs, override_shots=override_shots)
   1151 finally:
   1152     if old_interface == "auto":

File ~/Documents/pennylane/pennylane/workflow/, in QNode._execution_component(self, args, kwargs, override_shots)
   1100 _prune_dynamic_transform(full_transform_program, inner_transform_program)
   1102 # pylint: disable=unexpected-keyword-arg
-> 1103 res = qml.execute(
   1104     (self._tape,),
   1105     device=self.device,
   1106     gradient_fn=self.gradient_fn,
   1107     interface=self.interface,
   1108     transform_program=full_transform_program,
   1109     inner_transform=inner_transform_program,
   1110     config=config,
   1111     gradient_kwargs=self.gradient_kwargs,
   1112     override_shots=override_shots,
   1113     **self.execute_kwargs,
   1114 )
   1115 res = res[0]
   1117 # convert result to the interface in case the qfunc has no parameters

File ~/Documents/pennylane/pennylane/workflow/, in execute(tapes, device, gradient_fn, interface, transform_program, inner_transform, config, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform, device_vjp, mcm_config)
    827 ml_boundary_execute = _get_ml_boundary_execute(
    828     interface,
    829     _grad_on_execution,
    830     config.use_device_jacobian_product,
    831     differentiable=max_diff > 1,
    832 )
    834 if interface in jpc_interfaces:
--> 835     results = ml_boundary_execute(tapes, execute_fn, jpc, device=device)
    836 else:
    837     results = ml_boundary_execute(
    838         tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
    839     )

File ~/Documents/pennylane/pennylane/workflow/interfaces/, in jax_jvp_execute(tapes, execute_fn, jpc, device)
    261     logger.debug("Entry with (tapes=%s, execute_fn=%s, jpc=%s)", tapes, execute_fn, jpc)
    263 parameters = tuple(tuple(t.get_parameters()) for t in tapes)
--> 265 return _execute_jvp(parameters, _NonPytreeWrapper(tuple(tapes)), execute_fn, jpc)

    [... skipping hidden 6 frame]

File ~/Documents/pennylane/pennylane/workflow/interfaces/, in _execute_wrapper(params, tapes, execute_fn, jpc)
    228 """Executes ``tapes`` with ``params`` via ``execute_fn``"""
    229 new_tapes = set_parameters_on_copy_and_unwrap(tapes.vals, params, unwrap=False)
--> 230 return _to_jax(execute_fn(new_tapes))

File ~/Documents/pennylane/pennylane/workflow/, in _make_inner_execute.<locals>.inner_execute(tapes, **_)
    306 if cache is not None:
    307     transform_program.add_transform(_cache_transform, cache=cache)
--> 309 transformed_tapes, transform_post_processing = transform_program(tapes)
    311 # TODO: Apply expand_fn() as transform.
    312 if expand_fn:

File ~/Documents/pennylane/pennylane/transforms/core/, in TransformProgram.__call__(self, tapes)
    513 if self._argnums is not None and self._argnums[i] is not None:
    514     tape.trainable_params = self._argnums[i][j]
--> 515 new_tapes, fn = transform(tape, *targs, **tkwargs)
    516 execution_tapes.extend(new_tapes)
    518 fns.append(fn)

File ~/Documents/pennylane/pennylane/transforms/, in convert_to_numpy_parameters(tape)
     85 new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
     86 new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
---> 87 new_circuit = tape.__class__(
     88     new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
     89 )
     91 def null_postprocessing(results):
     92     """A postprocesing function returned by a transform that only converts the batch of results
     93     into a result for a single ``QuantumTape``.
     94     """

File ~/Documents/pennylane/pennylane/tape/, in QuantumScript.__init__(self, ops, measurements, shots, trainable_params)
    168 def __init__(
    169     self,
    170     ops=None,
    173     trainable_params: Optional[Sequence[int]] = None,
    174 ):
--> 175     self._ops = [] if ops is None else list(ops)
    176     self._measurements = [] if measurements is None else list(measurements)
    177     self._shots = Shots(shots)

File ~/Documents/pennylane/pennylane/transforms/, in <genexpr>(.0)
     50 @transform
     51 def convert_to_numpy_parameters(tape: QuantumScript) -> Tuple[Sequence[QuantumScript], Callable]:
     52     """Transforms a circuit to one with purely numpy parameters.
     54     Args:
     84     """
---> 85     new_ops = (_convert_op_to_numpy_data(op) for op in tape.operations)
     86     new_measurements = (_convert_measurement_to_numpy_data(m) for m in tape.measurements)
     87     new_circuit = tape.__class__(
     88         new_ops, new_measurements, shots=tape.shots, trainable_params=tape.trainable_params
     89     )

File ~/Documents/pennylane/pennylane/transforms/, in _convert_op_to_numpy_data(op)
     29     return op
     30 # Use operator method to change parameters when it become available
---> 31 return qml.ops.functions.bind_new_parameters(op, math.unwrap(

File ~/Documents/pennylane/pennylane/math/, in unwrap(values, max_depth)
    779     return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
    781 if isinstance(values, (tuple, list)):
--> 782     return type(values)(convert(val) for val in values)
    783 return (
    784     np.to_numpy(values, max_depth=max_depth)
    785     if isinstance(values, ArrayBox)
    786     else np.to_numpy(values)
    787 )

File ~/Documents/pennylane/pennylane/math/, in <genexpr>(.0)
    779     return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val
    781 if isinstance(values, (tuple, list)):
--> 782     return type(values)(convert(val) for val in values)
    783 return (
    784     np.to_numpy(values, max_depth=max_depth)
    785     if isinstance(values, ArrayBox)
    786     else np.to_numpy(values)
    787 )

File ~/Documents/pennylane/pennylane/math/, in unwrap.<locals>.convert(val)
    774 if isinstance(val, (tuple, list)):
    775     return unwrap(val)
    776 new_val = (
--> 777     np.to_numpy(val, max_depth=max_depth) if isinstance(val, ArrayBox) else np.to_numpy(val)
    778 )
    779 return new_val.tolist() if isinstance(new_val, ndarray) and not new_val.shape else new_val

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/autoray/, in do(fn, like, *args, **kwargs)
     79 backend = _choose_backend(fn, args, kwargs, like=like)
     80 func = get_lib_fn(backend, fn)
---> 81 return func(*args, **kwargs)

File ~/Documents/pennylane/pennylane/math/, in _to_numpy_jax(x)
    783     return np.array(getattr(x, "val", x))
    784 except TracerArrayConversionError as e:
--> 785     raise ValueError(
    786         "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
    787     ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
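The root error is JAX refusing to turn a traced value into a NumPy array: inside jax.jit, arguments are abstract tracers, so the np.array(...) call in _to_numpy_jax raises TracerArrayConversionError, which PennyLane then re-raises as the ValueError above. The following is a minimal JAX-only sketch of that behaviour, independent of PennyLane. The function host_only_square is a hypothetical stand-in for code that needs concrete NumPy values (such as a non-backprop device). The sketch also shows jax.pure_callback, a general JAX mechanism for running host-only code under jit; it illustrates the pattern only and is not a description of how PennyLane or Lightning implement their JAX interface.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_only_square(x):
    # Hypothetical stand-in for a non-backprop device: it only understands
    # concrete NumPy arrays, never JAX tracers.
    return np.asarray(x) ** 2


@jax.jit
def naive(x):
    # Under jax.jit, x is an abstract tracer, so np.asarray(x) inside
    # host_only_square raises jax.errors.TracerArrayConversionError.
    return host_only_square(x)


@jax.jit
def with_callback(x):
    # jax.pure_callback defers the NumPy call until runtime, when x is concrete.
    # We must declare the output shape and dtype up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_only_square, result_shape, x)


x = jnp.array([0.0, 1.0, 1.0])

try:
    naive(x)
except jax.errors.TracerArrayConversionError:
    print("naive: TracerArrayConversionError")

print("with_callback:", with_callback(x))
```

The callback version works because the host function only ever sees concrete arrays; the cost is that JAX cannot differentiate through it without extra rules.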

System information

Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.5-arm64-arm-64bit
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- default.clifford (PennyLane-0.38.0.dev0)
- default.gaussian (PennyLane-0.38.0.dev0)
- default.mixed (PennyLane-0.38.0.dev0)
- default.qubit (PennyLane-0.38.0.dev0)
- default.qubit.autograd (PennyLane-0.38.0.dev0)
- default.qubit.jax (PennyLane-0.38.0.dev0)
- default.qubit.legacy (PennyLane-0.38.0.dev0)
- (PennyLane-0.38.0.dev0)
- default.qubit.torch (PennyLane-0.38.0.dev0)
- default.qutrit (PennyLane-0.38.0.dev0)
- default.qutrit.mixed (PennyLane-0.38.0.dev0)
- default.tensor (PennyLane-0.38.0.dev0)
- null.qubit (PennyLane-0.38.0.dev0)
- lightning.qubit (PennyLane_Lightning-0.37.0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0)
- (PennyLane-Catalyst-0.7.0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0)


albi3ro commented 1 month ago

I'd bet the issue is that the features are treated as a hyperparameter rather than as numeric data. So we don't convert them to NumPy, and any non-backprop device can't handle values that aren't NumPy arrays.
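The hypothesis above can be illustrated with a toy model in plain Python. Everything in it (FrameworkArray, ToyOp, convert_data_only, device_accepts, and the "features" key) is hypothetical: it only mirrors the idea that a conversion pass touches an operator's data but not its hyperparameters, and it is not PennyLane's actual implementation.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FrameworkArray:
    """Hypothetical stand-in for a JAX array that a device cannot read directly."""
    value: list


@dataclass
class ToyOp:
    """Hypothetical operator with numeric `data` and static `hyperparameters`."""
    name: str
    data: tuple = ()
    hyperparameters: dict = field(default_factory=dict)


def to_host(x: Any) -> Any:
    # Unwrap framework arrays into plain lists (the "NumPy" of this toy model).
    return x.value if isinstance(x, FrameworkArray) else x


def convert_data_only(op: ToyOp) -> ToyOp:
    # Mimics a convert-to-NumPy pass that only rewrites op.data;
    # hyperparameters are copied through untouched.
    return ToyOp(op.name, tuple(to_host(d) for d in op.data), dict(op.hyperparameters))


def device_accepts(op: ToyOp) -> bool:
    # A toy non-backprop device: it rejects any framework array it finds,
    # whether in data or in hyperparameters.
    values = list(op.data) + list(op.hyperparameters.values())
    return not any(isinstance(v, FrameworkArray) for v in values)


features = FrameworkArray([0, 1, 1])

# Features stored as a hyperparameter: the conversion pass never sees them.
as_hyperparameter = convert_data_only(ToyOp("Embed", hyperparameters={"features": features}))

# Features stored as data: the conversion pass unwraps them.
as_data = convert_data_only(ToyOp("Embed", data=(features,)))

print(device_accepts(as_hyperparameter))
print(device_accepts(as_data))
```

In this model, moving the features from hyperparameters into data is exactly what lets the conversion step reach them, which matches the fix proposed in the reply below.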

KetpuntoG commented 1 month ago

Yes, I think so :) I'm adding it as data in my story 👍

DSGuala commented 9 hours ago

Closing as this was resolved in