PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[BUG] Error when taking `adjoint` of subroutines with `wires` arguments #334

Closed: glassnotes closed this issue 8 months ago

glassnotes commented 10 months ago

Issue description

I am on this branch/commit, but I was also experiencing this previously on the main branch.

Name: PennyLane
Version: 0.32.0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /home/olivia/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: pennylane-catalyst, PennyLane-Lightning, PennyLane-qiskit

Platform info:           Linux-6.2.0-35-generic-x86_64-with-glibc2.35
Python version:          3.11.0
Numpy version:           1.23.5
Scipy version:           1.10.1
Installed devices:
- default.gaussian (PennyLane-0.32.0)
- default.mixed (PennyLane-0.32.0)
- default.qubit (PennyLane-0.32.0)
- default.qubit.autograd (PennyLane-0.32.0)
- default.qubit.jax (PennyLane-0.32.0)
- default.qubit.tf (PennyLane-0.32.0)
- default.qubit.torch (PennyLane-0.32.0)
- default.qutrit (PennyLane-0.32.0)
- null.qubit (PennyLane-0.32.0)
- qiskit.aer (PennyLane-qiskit-0.32.0)
- qiskit.basicaer (PennyLane-qiskit-0.32.0)
- qiskit.ibmq (PennyLane-qiskit-0.32.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.32.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.32.0)
- qiskit.remote (PennyLane-qiskit-0.32.0)
- lightning.qubit (PennyLane-Lightning-0.32.0)

Source code and tracebacks

This version works:

import catalyst
import pennylane as qml
from catalyst import qjit, adjoint

dev = qml.device("lightning.qubit", wires=3)

def subroutine_no_wires():
    for wire in range(3):
        qml.PauliX(wire)

@qjit(autograph=True)
@qml.qnode(dev)
def test_function():
    adjoint(subroutine_no_wires)()
    return qml.probs()

However, if we instead do:

def subroutine(wires):
    for wire in wires:
        qml.PauliX(wire)

@qjit(autograph=True)
@qml.qnode(dev)
def test_function():
    adjoint(subroutine)(dev.wires)
    return qml.probs()

we obtain the following traceback:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/jax/_src/api_util.py:581, in shaped_abstractify(x)
    580 try:
--> 581   return _shaped_abstractify_handlers[type(x)](x)
    582 except KeyError:

KeyError: <class 'pennylane.wires.Wires'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[6], line 5
      2     for wire in wires:
      3         qml.PauliX(wire)
----> 5 @qjit(autograph=True)
      6 @qml.qnode(dev)
      7 def test_function():
      8     subroutine(dev.wires)
      9     adjoint(subroutine)(dev.wires)

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:948, in qjit.<locals>.wrap_fn(fn)
    947 def wrap_fn(fn):
--> 948     return QJIT(
    949         fn, CompileOptions(verbose, logfile, target, keep_intermediate, pipelines, autograph)
    950     )

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:508, in QJIT.__init__(self, fn, compile_options)
    506 if parameter_types is not None:
    507     self.user_typed = True
--> 508     self.mlir_module = self.get_mlir(*parameter_types)
    509     if self.compile_options.target == "binary":
    510         self.compiled_function = self.compile()

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:555, in QJIT.get_mlir(self, *args)
    550 self.c_sig = CompiledFunction.get_runtime_signature(*args)
    552 with Patcher(
    553     (qml.QNode, "__call__", QFunc.__call__),
    554 ):
--> 555     mlir_module, ctx, jaxpr, self.shape = trace_to_mlir(self.user_function, *self.c_sig)
    557 inject_functions(mlir_module, ctx)
    558 self._jaxpr = jaxpr

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:276, in trace_to_mlir(func, *args, **kwargs)
    273 mlir_fn_cache.clear()
    275 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 276     jaxpr, shape = jax.make_jaxpr(func, return_shape=True)(*args, **kwargs)
    278 return jaxpr_to_mlir(func.__name__, jaxpr, shape)

    [... skipping hidden 6 frame]

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:149, in QFunc.__call__(self, *args, **kwargs)
    146     device = self.device
    148 with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION):
--> 149     jaxpr, shape = trace_quantum_function(self.func, device, args, kwargs)
    151 retval_tree = tree_structure(shape)
    153 def _eval_jaxpr(*args):

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:535, in trace_quantum_function(f, device, args, kwargs)
    532     in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
    533     with QueuingManager.stop_recording(), quantum_tape:
    534         # Quantum tape transformations happen at the end of tracing
--> 535         ans = wffa.call_wrapped(*in_classical_tracers)
    536     out_classical_tracers_or_measurements = [
    537         (trace.full_raise(t) if isinstance(t, DynamicJaxprTracer) else t) for t in ans
    538     ]
    540 # (2) - Quantum tracing

    [... skipping hidden 1 frame]

File /tmp/__autograph_generated_fileeeja8hs_.py:9, in outer_factory.<locals>.inner_factory.<locals>.test_function_1()
      7 with ag__.FunctionScope('test_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
      8     ag__.converted_call(subroutine, (dev.wires,), None, fscope)
----> 9     ag__.converted_call(ag__.converted_call(adjoint, (subroutine,), None, fscope), (dev.wires,), None, fscope)
     10     return ag__.converted_call(qml.probs, (), None, fscope)

File ~/Code/catalyst/frontend/catalyst/ag_primitives.py:441, in converted_call(fn, args, kwargs, caller_fn_scope, options)
    438     new_qnode = qml.QNode(qnode_call_wrapper, device=fn.device, diff_method=fn.diff_method)
    439     return new_qnode()
--> 441 return tf_converted_call(fn, args, kwargs, caller_fn_scope, options)

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:377, in converted_call(f, args, kwargs, caller_fn_scope, options)
    374   return _call_unconverted(f, args, kwargs, options)
    376 if not options.user_requested and conversion.is_allowlisted(f):
--> 377   return _call_unconverted(f, args, kwargs, options)
    379 # internal_convert_user_code is for example turned off when issuing a dynamic
    380 # call conversion from generated code while in nonrecursive mode. In that
    381 # case we evidently don't want to recurse, but we still have to convert
    382 # things like builtins.
    383 if not options.internal_convert_user_code:

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:460, in _call_unconverted(f, args, kwargs, options, update_cache)
    458 if kwargs is not None:
    459   return f(*args, **kwargs)
--> 460 return f(*args)

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:1814, in adjoint.<locals>._callable(*args, **kwargs)
   1813 def _callable(*args, **kwargs):
-> 1814     return _call_handler(*args, _callee=f, **kwargs)

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:1788, in adjoint.<locals>._call_handler(_callee, *args, **kwargs)
   1786 with EvaluationContext.frame_tracing_context(ctx) as inner_trace:
   1787     in_classical_tracers, _ = tree_flatten((args, kwargs))
-> 1788     wffa, in_avals, _ = deduce_avals(_callee, args, kwargs)
   1789     arg_classical_tracers = _input_type_to_tracers(inner_trace.new_arg, in_avals)
   1790     quantum_tape = QuantumTape()

File ~/Code/catalyst/frontend/catalyst/utils/jax_extras.py:266, in deduce_avals(f, args, kwargs)
    264 flat_args, in_tree = tree_flatten((args, kwargs))
    265 wf = wrap_init(f)
--> 266 in_avals, keep_inputs = list(map(shaped_abstractify, flat_args)), [True] * len(flat_args)
    267 in_type = tuple(zip(in_avals, keep_inputs))
    268 wff, out_tree_promise = flatten_fun(wf, in_tree)

    [... skipping hidden 1 frame]

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/jax/_src/api_util.py:572, in _shaped_abstractify_slow(x)
    570   dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
    571 else:
--> 572   raise TypeError(
    573       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    574       "does not have a dtype attribute")
    575 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    576                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'pennylane.wires.Wires'> as an abstract array; it does not have a dtype attribute
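The last frames point at the cause. `adjoint`'s `_call_handler` hands the call's `args` and `kwargs` to `deduce_avals`, which flattens them and runs JAX's `shaped_abstractify` on each leaf, and a `pennylane.wires.Wires` object has no `dtype`. The sketch below is a simplified pure-Python model of that check, not Catalyst's or JAX's actual code. `FakeWires`, `abstractify`, and `abstractify_call_args` are hypothetical names, and the pytree flattening is reduced to collecting positional and keyword values.

```python
from typing import Any


class FakeWires:
    """Hypothetical stand-in for pennylane.wires.Wires: iterable, but no dtype."""

    def __init__(self, labels):
        self.labels = tuple(labels)

    def __iter__(self):
        return iter(self.labels)


def abstractify(x: Any) -> tuple:
    """Simplified model of the shaped_abstractify check: (shape, dtype name) or TypeError."""
    if isinstance(x, (bool, int, float, complex)):
        # Python scalars are accepted and treated as 0-d values.
        return ((), type(x).__name__)
    if not hasattr(x, "dtype"):
        raise TypeError(
            f"Cannot interpret value of type {type(x)} as an abstract array; "
            "it does not have a dtype attribute"
        )
    return (tuple(getattr(x, "shape", ())), str(x.dtype))


def abstractify_call_args(*args, **kwargs) -> list:
    """Model of deduce_avals: every positional and keyword value is abstractified."""
    flat_args = list(args) + list(kwargs.values())
    return [abstractify(leaf) for leaf in flat_args]


print(abstractify_call_args(0, 1.5))  # [((), 'int'), ((), 'float')]

try:
    abstractify_call_args(FakeWires(range(3)))
except TypeError as err:
    print("TypeError:", err)
```

Scalars pass the check, while the `FakeWires` instance raises the same kind of `TypeError` as the traceback, regardless of whether it is passed positionally or by keyword.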

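Additionally, if we pass the wires as a `jnp.array` (with `jax.numpy` imported as `jnp`), a different error occurs. This time it is raised while evaluating the argument `jnp.array(dev.wires)`, before the callable returned by `adjoint` is invoked: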

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[18], line 5
      2     for wire in wires:
      3         qml.PauliX(wire)
----> 5 @qjit(autograph=True)
      6 @qml.qnode(dev)
      7 def test_function():
      8     #subroutine(dev.wires)
      9     adjoint(subroutine)(jnp.array(dev.wires))
     10     return qml.probs()

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:948, in qjit.<locals>.wrap_fn(fn)
    947 def wrap_fn(fn):
--> 948     return QJIT(
    949         fn, CompileOptions(verbose, logfile, target, keep_intermediate, pipelines, autograph)
    950     )

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:508, in QJIT.__init__(self, fn, compile_options)
    506 if parameter_types is not None:
    507     self.user_typed = True
--> 508     self.mlir_module = self.get_mlir(*parameter_types)
    509     if self.compile_options.target == "binary":
    510         self.compiled_function = self.compile()

File ~/Code/catalyst/frontend/catalyst/compilation_pipelines.py:555, in QJIT.get_mlir(self, *args)
    550 self.c_sig = CompiledFunction.get_runtime_signature(*args)
    552 with Patcher(
    553     (qml.QNode, "__call__", QFunc.__call__),
    554 ):
--> 555     mlir_module, ctx, jaxpr, self.shape = trace_to_mlir(self.user_function, *self.c_sig)
    557 inject_functions(mlir_module, ctx)
    558 self._jaxpr = jaxpr

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:276, in trace_to_mlir(func, *args, **kwargs)
    273 mlir_fn_cache.clear()
    275 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
--> 276     jaxpr, shape = jax.make_jaxpr(func, return_shape=True)(*args, **kwargs)
    278 return jaxpr_to_mlir(func.__name__, jaxpr, shape)

    [... skipping hidden 6 frame]

File ~/Code/catalyst/frontend/catalyst/pennylane_extensions.py:149, in QFunc.__call__(self, *args, **kwargs)
    146     device = self.device
    148 with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION):
--> 149     jaxpr, shape = trace_quantum_function(self.func, device, args, kwargs)
    151 retval_tree = tree_structure(shape)
    153 def _eval_jaxpr(*args):

File ~/Code/catalyst/frontend/catalyst/jax_tracer.py:535, in trace_quantum_function(f, device, args, kwargs)
    532     in_classical_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
    533     with QueuingManager.stop_recording(), quantum_tape:
    534         # Quantum tape transformations happen at the end of tracing
--> 535         ans = wffa.call_wrapped(*in_classical_tracers)
    536     out_classical_tracers_or_measurements = [
    537         (trace.full_raise(t) if isinstance(t, DynamicJaxprTracer) else t) for t in ans
    538     ]
    540 # (2) - Quantum tracing

    [... skipping hidden 1 frame]

File /tmp/__autograph_generated_filezwu16kl_.py:8, in outer_factory.<locals>.inner_factory.<locals>.test_function_1()
      6 def test_function_1():
      7     with ag__.FunctionScope('test_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
----> 8         ag__.converted_call(ag__.converted_call(adjoint, (subroutine,), None, fscope), (ag__.converted_call(jnp.array, (dev.wires,), None, fscope),), None, fscope)
      9         return ag__.converted_call(qml.probs, (), None, fscope)

File ~/Code/catalyst/frontend/catalyst/ag_primitives.py:441, in converted_call(fn, args, kwargs, caller_fn_scope, options)
    438     new_qnode = qml.QNode(qnode_call_wrapper, device=fn.device, diff_method=fn.diff_method)
    439     return new_qnode()
--> 441 return tf_converted_call(fn, args, kwargs, caller_fn_scope, options)

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:377, in converted_call(f, args, kwargs, caller_fn_scope, options)
    374   return _call_unconverted(f, args, kwargs, options)
    376 if not options.user_requested and conversion.is_allowlisted(f):
--> 377   return _call_unconverted(f, args, kwargs, options)
    379 # internal_convert_user_code is for example turned off when issuing a dynamic
    380 # call conversion from generated code while in nonrecursive mode. In that
    381 # case we evidently don't want to recurse, but we still have to convert
    382 # things like builtins.
    383 if not options.internal_convert_user_code:

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/tensorflow/python/autograph/impl/api.py:460, in _call_unconverted(f, args, kwargs, options, update_cache)
    458 if kwargs is not None:
    459   return f(*args, **kwargs)
--> 460 return f(*args)

File ~/Software/anaconda3/envs/catalyst/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2051, in array(object, dtype, copy, order, ndmin)
   2044 out: ArrayLike
   2046 if all(not isinstance(leaf, Array) for leaf in leaves):
   2047   # TODO(jakevdp): falling back to numpy here fails to overflow for lists
   2048   # containing large integers; see discussion in
   2049   # https://github.com/google/jax/pull/6047. More correct would be to call
   2050   # coerce_to_array on each leaf, but this may have performance implications.
-> 2051   out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)  # type: ignore[arg-type]
   2052 elif isinstance(object, Array):
   2053   assert object.aval is not None

TypeError: Wires.__array__() takes 1 positional argument but 2 were given
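This second error comes from converting `dev.wires` rather than from `adjoint` itself. `jnp.array` falls back to `np.array(object, dtype=dtype, ...)`, and the message shows that NumPy called `Wires.__array__` with one more positional argument (the requested dtype) than that method accepts in this PennyLane version. The pure-Python sketch below reproduces only that signature mismatch. `OldStyleWires`, `DtypeAwareWires`, and `request_array` are hypothetical names, and NumPy is not imported.

```python
class OldStyleWires:
    """Hypothetical class whose __array__ takes no dtype, as in the traceback."""

    def __init__(self, labels):
        self.labels = list(labels)

    def __array__(self):
        return self.labels


class DtypeAwareWires(OldStyleWires):
    """Hypothetical variant that accepts the optional dtype argument."""

    def __array__(self, dtype=None):
        # Convert each label only when a dtype (here: a Python type) is requested.
        return self.labels if dtype is None else [dtype(x) for x in self.labels]


def request_array(obj, dtype):
    # Mimics the protocol call made when a dtype is requested: obj.__array__(dtype).
    return obj.__array__(dtype)


try:
    request_array(OldStyleWires(range(3)), int)
except TypeError as err:
    print(err)  # message ends with: __array__() takes 1 positional argument but 2 were given

print(request_array(DtypeAwareWires(range(3)), float))  # [0.0, 1.0, 2.0]
```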


sergei-mironov commented 8 months ago

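FYI: the following workaround should work. It binds `wires` with `functools.partial`, so the callable returned by `adjoint` is invoked without arguments:

from functools import partial

import pennylane as qml
from catalyst import adjoint, qjit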

def subroutine(*, wires):
    for wire in wires:
        qml.PauliX(wire)

dev = qml.device("lightning.qubit", wires=3)

@qjit(autograph=True)
@qml.qnode(dev)
def test_function():
    adjoint(partial(subroutine, wires=dev.wires))()
    return qml.probs()

test_function()
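The workaround succeeds because `partial` stores `wires=dev.wires` inside the callable itself, so the call made through `adjoint` has no arguments left to abstractify; the `Wires` object is only used as an ordinary Python value while the subroutine body runs. A pure-Python sketch of the same idea follows. `FakeWires` and `strict_wrapper` are hypothetical stand-ins that model only the argument check, not Catalyst's `adjoint`, and `subroutine` records gate names instead of applying them.

```python
from functools import partial


class FakeWires:
    """Hypothetical stand-in for pennylane.wires.Wires: iterable, but no dtype."""

    def __init__(self, labels):
        self.labels = tuple(labels)

    def __iter__(self):
        return iter(self.labels)


def strict_wrapper(f):
    """Hypothetical wrapper modelling only the argument check done when adjoint traces a call.

    It rejects any call argument without a dtype and otherwise calls f unchanged;
    it does not reverse or invert anything.
    """

    def _callable(*args, **kwargs):
        for leaf in list(args) + list(kwargs.values()):
            if not hasattr(leaf, "dtype"):
                raise TypeError(f"{type(leaf).__name__} has no dtype attribute")
        return f(*args, **kwargs)

    return _callable


def subroutine(*, wires):
    # Records the gates that would be applied instead of calling qml.PauliX.
    return [("PauliX", w) for w in wires]


wires = FakeWires(range(3))

# Passing wires at call time: the wrapper inspects the argument and fails.
try:
    strict_wrapper(subroutine)(wires=wires)
except TypeError as err:
    print("direct call failed:", err)

# Binding wires with partial: the wrapper receives no arguments to inspect.
ops = strict_wrapper(partial(subroutine, wires=wires))()
print(ops)  # [('PauliX', 0), ('PauliX', 1), ('PauliX', 2)]
```

The keyword-only signature `subroutine(*, wires)` is not what makes this work; any argument bound inside the `partial` object is hidden from the wrapper in the same way.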