PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[BUG] Using `mcm_method="one-shot"` when no MCMs are present causes an error #928

Open isaacdevlugt opened 1 month ago

isaacdevlugt commented 1 month ago

Issue description

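Specifying mcm_method="one-shot" on a qjit-compiled QNode that contains no mid-circuit measurements (MCMs) raises a ValueError while the function is traced for compilation, but only when the QNode returns an iterable of measurements (e.g., a list of expectation values).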

>>> qml.about()
Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.5-arm64-arm-64bit
Python version:          3.11.8
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.37.0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0)
- oqc.cloud (PennyLane-Catalyst-0.7.0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0)
- default.clifford (PennyLane-0.37.0)
- default.gaussian (PennyLane-0.37.0)
- default.mixed (PennyLane-0.37.0)
- default.qubit (PennyLane-0.37.0)
- default.qubit.autograd (PennyLane-0.37.0)
- default.qubit.jax (PennyLane-0.37.0)
- default.qubit.legacy (PennyLane-0.37.0)
- default.qubit.tf (PennyLane-0.37.0)
- default.qubit.torch (PennyLane-0.37.0)
- default.qutrit (PennyLane-0.37.0)
- default.qutrit.mixed (PennyLane-0.37.0)
- default.tensor (PennyLane-0.37.0)
- null.qubit (PennyLane-0.37.0)
>>>
>>> catalyst.__version__
'0.7.0'

Source code and tracebacks

import pennylane as qml  # qml.qjit also requires pennylane-catalyst to be installed

qubits = 3
dev = qml.device("lightning.qubit", wires=qubits, shots=10)

@qml.qjit
@qml.qnode(dev, mcm_method="one-shot")  # works without specifying mcm_method="one-shot"
def cost():
    qml.Hadamard(0)
    qml.CNOT([0, 1])
    # return qml.expval(qml.Z(0))  # works
    return [qml.expval(qml.Z(i)) for i in range(qubits)]  # doesn't work

print(cost())

Traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/util.py:286, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    285 else:
--> 286   return cached(config.config._trace_context(), *args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/util.py:279, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    277 @functools.lru_cache(max_size)
    278 def cached(_, *args, **kwargs):
--> 279   return f(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/lax/lax.py:155, in _broadcast_shapes_cached(*shapes)
    153 @cache()
    154 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 155   return _broadcast_shapes_uncached(*shapes)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(10, 3), (1, 0)]

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[3], line 4
      1 qubits = 3
      2 dev = qml.device("lightning.qubit", wires=qubits, shots=10)
----> 4 @qml.qjit
      5 @qml.qnode(dev, mcm_method="one-shot") # works without specifying mcm_method="one-shot"
      6 def cost():
      7     qml.Hadamard(0)
      8     qml.CNOT([0, 1])

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/compiler/qjit_api.py:301, in qjit(fn, compiler, *args, **kwargs)
    299 compilers = AvailableCompilers.names_entrypoints
    300 qjit_loader = compilers[compiler]["qjit"].load()
--> 301 return qjit_loader(fn=fn, *args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:424, in qjit(fn, autograph, autograph_include, async_qnodes, target, keep_intermediate, verbose, logfile, pipelines, static_argnums, abstracted_axes)
    421 if fn is None:
    422     return functools.partial(qjit, **kwargs)
--> 424 return QJIT(fn, CompileOptions(**kwargs))

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:491, in QJIT.__init__(self, fn, compile_options)
    489 # Static arguments require values, so we cannot AOT compile.
    490 if self.user_sig is not None and not self.compile_options.static_argnums:
--> 491     self.aot_compile()

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:525, in QJIT.aot_compile(self)
    520 if self.compile_options.target in ("jaxpr", "mlir", "binary"):
    521     # Capture with the patched conversion rules
    522     with Patcher(
    523         (ag_primitives, "module_allowlist", self.patched_module_allowlist),
    524     ):
--> 525         self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
    526             self.user_sig or ()
    527         )
    529 if self.compile_options.target in ("mlir", "binary"):
    530     self.mlir_module, self.mlir = self.generate_ir()

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
    140 @functools.wraps(fn)
    141 def wrapper(*args, **kwargs):
    142     if not InstrumentSession.active:
--> 143         return fn(*args, **kwargs)
    145     with ResultReporter(stage_name, has_finegrained) as reporter:
    146         fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:628, in QJIT.capture(self, args)
    622 full_sig = merge_static_args(dynamic_sig, args, static_argnums)
    624 with Patcher(
    625     (qml.QNode, "__call__", QFunc.__call__),
    626 ):
    627     # TODO: improve PyTree handling
--> 628     jaxpr, out_type, treedef = trace_to_jaxpr(
    629         self.user_function, static_argnums, abstracted_axes, full_sig, {}
    630     )
    632 return jaxpr, out_type, treedef, dynamic_sig

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jax_tracer.py:531, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs)
    526     with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
    527         make_jaxpr_kwargs = {
    528             "static_argnums": static_argnums,
    529             "abstracted_axes": abstracted_axes,
    530         }
--> 531         jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
    533 return jaxpr, out_type, out_treedef

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jax_extras/tracing.py:539, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    537     f, out_tree_promise = flatten_fun(f, in_tree)
    538     f = annotate(f, in_type)
--> 539     jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
    540 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
    541 return closed_jaxpr, out_type, out_tree_promise()

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
    333 @wraps(func)
    334 def wrapper(*args, **kwargs):
    335   with TraceAnnotation(name, **decorator_kwargs):
--> 336     return func(*args, **kwargs)
    337   return wrapper

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2324, in trace_to_jaxpr_dynamic2(fun, debug_info)
   2322 with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
   2323   main.jaxpr_stack = ()  # type: ignore
-> 2324   jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2325   del main, fun
   2326 return jaxpr, out_type, consts

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2339, in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2337 in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
   2338 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2339 ans = fun.call_wrapped(*in_tracers_)
   2340 out_tracers = map(trace.full_raise, ans)
   2341 jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
    188 gen = gen_static_args = out_store = None
    190 try:
--> 191   ans = self.f(*args, **dict(self.params, **kwargs))
    192 except:
    193   # Some transformations yield from inside context managers, so we have to
    194   # interrupt them before reraising the exception. Otherwise they will only
    195   # get garbage-collected at some later time, running their cleanup tasks
    196   # only after this exception is handled, which can corrupt the global
    197   # state.
    198   while stack:

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/qfunc.py:122, in QFunc.__call__(self, *args, **kwargs)
    120     if mcm_config.mcm_method == "one-shot":
    121         mcm_config.postselect_mode = mcm_config.postselect_mode or "hw-like"
--> 122         return dynamic_one_shot(self, mcm_config=mcm_config)(*args, **kwargs)
    124 # TODO: Move the capability loading and validation to the device constructor when the
    125 # support for old device api is dropped.
    126 program_features = ProgramFeatures(self.device.shots is not None)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/qfunc.py:259, in dynamic_one_shot.<locals>.one_shot_wrapper(*args, **kwargs)
    257 if isinstance(results[0], tuple) and len(results) == 1:
    258     results = results[0]
--> 259 return parse_native_mid_circuit_measurements(cpy_tape, aux_tapes, results, interface="jax")

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/transforms/dynamic_one_shot.py:264, in parse_native_mid_circuit_measurements(circuit, aux_tapes, results, interface)
    258 has_postselect = qml.math.array(
    259     [[int(op.postselect is not None) for op in all_mcms]], like=interface
    260 )
    261 postselect = qml.math.array(
    262     [[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface
    263 )
--> 264 is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
    265 has_valid = qml.math.any(is_valid)
    266 mid_meas = [op for op in circuit.operations if is_mcm(op)]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:743, in _forward_operator_to_aval.<locals>.op(self, *args)
    742 def op(self, *args):
--> 743   return getattr(self.aval, f"_{name}")(self, *args)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:271, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    269 args = (other, self) if swap else (self, other)
    270 if isinstance(other, _accepted_binop_types):
--> 271   return binary_op(*args)
    272 # Note: don't use isinstance here, because we don't want to raise for
    273 # subclasses, e.g. NamedTuple objects that may override operators.
    274 if type(other) in _rejected_binop_types:

    [... skipping hidden 12 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/numpy/ufuncs.py:98, in _maybe_bool_binop.<locals>.fn(x1, x2)
     97 def fn(x1, x2, /):
---> 98   x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
     99   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/numpy/util.py:381, in promote_args(fun_name, *args)
    379 _check_no_float0s(fun_name, *args)
    380 check_for_prngkeys(fun_name, *args)
--> 381 return promote_shapes(fun_name, *promote_dtypes(*args))

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/numpy/util.py:239, in promote_shapes(fun_name, *args)
    234 shapes = [np.shape(arg) for arg in args]
    235 if config.dynamic_shapes.value:
    236   # With dynamic shapes we don't support singleton-dimension broadcasting;
    237   # we instead broadcast out to the full shape as a temporary workaround.
    238   # TODO(mattjj): revise this workaround
--> 239   res_shape = lax.broadcast_shapes(*shapes)  # Can raise an error!
    240   return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
    241 else:

    [... skipping hidden 1 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/lax/lax.py:171, in _broadcast_shapes_uncached(*shapes)
    169 result_shape = _try_broadcast_shapes(shape_list)
    170 if result_shape is None:
--> 171   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    172 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(10, 3), (1, 0)]
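The last frames point at the postselection check in parse_native_mid_circuit_measurements. The following is a minimal NumPy sketch of that shape mismatch, not the actual PennyLane code. It assumes, based only on the error message, that with no MCMs the mcm_samples operand holds the (shots, 3) terminal results, while has_postselect and postselect are built from an empty MCM list and therefore have shape (1, 0). NumPy words the error differently from JAX's dynamic-shapes path, but the broadcasting rule is the same: trailing dimensions 3 and 0 are incompatible. The valid_shot_mask helper is hypothetical and shows one possible guard; it is not Catalyst's or PennyLane's fix.

```python
import numpy as np

# Illustrative stand-ins (assumptions) for the arrays built in
# parse_native_mid_circuit_measurements; the values are placeholders.
shots, n_results = 10, 3
all_mcms = []  # the reproducer contains no mid-circuit measurements

# Same comprehension pattern as in the traceback: an empty inner list
# yields arrays of shape (1, 0).
has_postselect = np.array([[0 for _ in all_mcms]], dtype=int)
postselect = np.array([[0 for _ in all_mcms]], dtype=int)

# Assumed from the error message: the samples operand has shape (shots, n_results).
mcm_samples = np.zeros((shots, n_results), dtype=int)

error = None
try:
    np.all(mcm_samples * has_postselect == postselect, axis=1)
except ValueError as exc:  # trailing dims 3 and 0 cannot be broadcast
    error = exc


def valid_shot_mask(samples, has_post, post):
    """Hypothetical guard: per-shot validity mask for postselection.

    With zero MCMs there is nothing to postselect on, so every shot is valid.
    """
    if has_post.shape[-1] == 0:
        return np.ones(samples.shape[0], dtype=bool)
    return np.all(samples * has_post == post, axis=1)


mask = valid_shot_mask(mcm_samples, has_postselect, postselect)
print(has_postselect.shape, type(error).__name__, mask.shape, bool(mask.all()))
# prints: (1, 0) ValueError (10,) True
```

Because the number of MCMs is already known when the tape is built, a shape check like this would be a static Python branch, so it would likely also be safe inside a JAX trace.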

Additional information

dime10 commented 1 month ago

Oh, that's good to know, thanks 😬 I just suggested to UF that they use this feature for easier device-based noise testing, where MCMs wouldn't necessarily be present. (It seems to only affect circuits with multiple results, though.)

Speaking of which, it makes me wonder whether we should think of this feature as an "MCM" feature rather than just an "execution mode" feature. @josh146

josh146 commented 1 week ago

@dime10 just double checking since I missed the iteration planning meeting -- are we in the process of resolving this one this iteration?

isaacdevlugt commented 1 week ago

Didn't realize that my traceback section was improperly copy-pasted from VS Code 😅. Fixed!

dime10 commented 1 week ago


This one hasn't been scheduled yet because #929 seems more pressing.