PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0
122 stars 27 forks source link

[BUG] Differentiating a qjit'd circuit with `qml.Rot` raises an error #969

Closed isaacdevlugt closed 1 month ago

isaacdevlugt commented 1 month ago

Issue description

A qjit'd circuit on `lightning.qubit` that contains `qml.Rot` cannot be differentiated, with either `qml.grad` or `catalyst.grad`.

Platform info: macOS-14.5-arm64-arm-64bit Python version: 3.11.8 Numpy version: 1.26.4 Scipy version: 1.12.0 Installed devices:

Source code and tracebacks

import pennylane as qml
from pennylane import numpy as pnp

n_qubits = 1

dev = qml.device("lightning.qubit", wires=n_qubits)

# Minimal reproduction: a qjit-compiled QNode on lightning.qubit that
# contains a single qml.Rot gate.
@qml.qjit
@qml.qnode(dev)
def circuit(params):
    """Apply Rot(params[0], params[1], params[2]) on wire 0 and return <Z> on wire 0."""
    qml.Rot(params[0], params[1], params[2], wires=0)
    return qml.expval(qml.PauliZ(0))

# NOTE(review): params is a PennyLane (autograd) numpy array. Per the traceback
# below, qml.grad appears to feed autograd boxes into the compiled function,
# which then fails with "Converting dtype('O') to a ctypes type".
params = pnp.random.random(size=(3,))
qml.grad(circuit)(params)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[58], line 1
----> 1 qml.grad(circuit)(params)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/_grad.py:166, in grad.__call__(self, *args, **kwargs)
    163     self._forward = self._fun(*args, **kwargs)
    164     return ()
--> 166 grad_value, ans = grad_fn(*args, **kwargs)  # pylint: disable=not-callable
    167 self._forward = ans
    169 return grad_value

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/autograd/wrap_util.py:20, in unary_to_nary.<locals>.nary_operator.<locals>.nary_f(*args, **kwargs)
     18 else:
     19     x = tuple(args[i] for i in argnum)
---> 20 return unary_operator(unary_f, x, *nary_op_args, **nary_op_kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/_grad.py:184, in grad._grad_with_forward(fun, x)
    178 @staticmethod
    179 @unary_to_nary
    180 def _grad_with_forward(fun, x):
    181     """This function is a replica of ``autograd.grad``, with the only
    182     difference being that it returns both the gradient *and* the forward pass
    183     value."""
--> 184     vjp, ans = _make_vjp(fun, x)  # pylint: disable=redefined-outer-name
    186     if vspace(ans).size != 1:
    187         raise TypeError(
    188             "Grad only applies to real scalar-output functions. "
    189             "Try jacobian, elementwise_grad or holomorphic_grad."
    190         )

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/autograd/core.py:10, in make_vjp(fun, x)
      8 def make_vjp(fun, x):
      9     start_node = VJPNode.new_root()
---> 10     end_value, end_node =  trace(start_node, fun, x)
     11     if end_node is None:
     12         def vjp(g): return vspace(x).zeros()

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/autograd/tracer.py:10, in trace(start_node, fun, x)
      8 with trace_stack.new_trace() as t:
      9     start_box = new_box(x, t, start_node)
---> 10     end_box = fun(start_box)
     11     if isbox(end_box) and end_box._trace == start_box._trace:
     12         return end_box._value, end_box._node

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/autograd/wrap_util.py:15, in unary_to_nary.<locals>.nary_operator.<locals>.nary_f.<locals>.unary_f(x)
     13 else:
     14     subargs = subvals(args, zip(argnum, x))
---> 15 return fun(*subargs, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:511, in QJIT.__call__(self, *args, **kwargs)
    508     dynamic_args = filter_static_args(args, self.compile_options.static_argnums)
    509     args = promote_arguments(self.c_sig, dynamic_args)
--> 511 return self.run(args, kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
    140 @functools.wraps(fn)
    141 def wrapper(*args, **kwargs):
    142     if not InstrumentSession.active:
--> 143         return fn(*args, **kwargs)
    145     with ResultReporter(stage_name, has_finegrained) as reporter:
    146         fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:703, in QJIT.run(self, args, kwargs)
    690 @instrument(has_finegrained=True)
    691 @debug_logger
    692 def run(self, args, kwargs):
    693     """Invoke a previously compiled function with the supplied arguments.
    694 
    695     Args:
   (...)
    700         Any: results of the execution arranged into the original function's output PyTrees
    701     """
--> 703     results = self.compiled_function(*args, **kwargs)
    705     # TODO: Move this to the compiled function object.
    706     return tree_unflatten(self.out_treedef, results)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/compiled_functions.py:338, in CompiledFunction.__call__(self, *args, **kwargs)
    333     abstracted_axes = self.compile_options.abstracted_axes
    334     dynamic_args = get_implicit_and_explicit_flat_args(
    335         abstracted_axes, *dynamic_args, **kwargs
    336     )
--> 338 abi_args, _buffer = self.args_to_memref_descs(self.restype, dynamic_args)
    340 numpy_dict = {nparr.ctypes.data: nparr for nparr in _buffer}
    342 result = CompiledFunction._exec(
    343     self.shared_object,
    344     self.restype,
   (...)
    347     *abi_args,
    348 )

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/compiled_functions.py:298, in CompiledFunction.args_to_memref_descs(self, restype, args)
    296     numpy_arg = np.asarray(arg)
    297     numpy_arg_buffer.append(numpy_arg)
--> 298     c_abi_ptr = ctypes.pointer(get_ranked_memref_descriptor(numpy_arg))
    299     c_abi_args.append(c_abi_ptr)
    301 args = tree_unflatten(args_shape, c_abi_args)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/utils/jnp_to_memref.py:67, in get_ranked_memref_descriptor(array)
     64     return get_ranked_memref_descriptor_from_shaped_array(array)
     66 # Use default implementation from MLIR's library.
---> 67 return mlir_get_ranked_memref_descriptor(array)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/mlir_quantum/runtime/np_to_memref.py:88, in get_ranked_memref_descriptor(nparray)
     86 def get_ranked_memref_descriptor(nparray):
     87     """Returns a ranked memref descriptor for the given numpy array."""
---> 88     ctp = as_ctype(nparray.dtype)
     89     if nparray.ndim == 0:
     90         x = make_zero_d_memref_descriptor(ctp)()

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/mlir_quantum/runtime/np_to_memref.py:38, in as_ctype(dtp)
     36 if dtp == np.dtype(np.float16):
     37     return F16
---> 38 return np.ctypeslib.as_ctypes_type(dtp)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/numpy/ctypeslib.py:503, in as_ctypes_type(dtype)
    465 def as_ctypes_type(dtype):
    466     r"""
    467     Convert a dtype into a ctypes type.
    468 
   (...)
    501 
    502     """
--> 503     return _ctype_from_dtype(_dtype(dtype))

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/numpy/ctypeslib.py:462, in _ctype_from_dtype(dtype)
    460     return _ctype_from_dtype_subarray(dtype)
    461 else:
--> 462     return _ctype_from_dtype_scalar(dtype)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/numpy/ctypeslib.py:384, in _ctype_from_dtype_scalar(dtype)
    382     ctype = _scalar_type_map[dtype_native]
    383 except KeyError as e:
--> 384     raise NotImplementedError(
    385         "Converting {!r} to a ctypes type".format(dtype)
    386     ) from None
    388 if dtype_with_endian.byteorder == '>':
    389     ctype = ctype.__ctype_be__

NotImplementedError: Converting dtype('O') to a ctypes type

Using catalyst.grad instead:

---------------------------------------------------------------------------
DifferentiableCompileError                Traceback (most recent call last)
Cell In[61], line 15
     12     return qml.expval(qml.PauliZ(0))
     14 params = pnp.random.random(size=(n_qubits,))
---> 15 catalyst.grad(circuit)(params)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/api_extensions/differentiation.py:644, in Grad.__call__(self, *args, **kwargs)
    642         results = jax.value_and_grad(self.fn, argnums=argnums)(*args)
    643     else:
--> 644         results = jax.grad(self.fn, argnums=argnums)(*args)
    645 else:
    646     assert (
    647         not self.grad_params.with_value
    648     ), "value_and_grad cannot be used with a Jacobian"

    [... skipping hidden 10 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:505, in QJIT.__call__(self, *args, **kwargs)
    503     if self.jaxed_function is None:
    504         self.jaxed_function = JAX_QJIT(self)  # lazy gradient compilation
--> 505     return self.jaxed_function(*args, **kwargs)
    507 elif requires_promotion:
    508     dynamic_args = filter_static_args(args, self.compile_options.static_argnums)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:843, in JAX_QJIT.__call__(self, *args, **kwargs)
    841 @debug_logger
    842 def __call__(self, *args, **kwargs):
--> 843     return self.jaxed_function(*args, **kwargs)

    [... skipping hidden 5 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:818, in JAX_QJIT.compute_jvp(self, primals, tangents)
    816 results = self.wrap_callback(self.qjit_function, *primals)
    817 results_data, _results_shape = tree_flatten(results)
--> 818 derivatives = self.wrap_callback(self.get_derivative_qjit(argnums), *primals)
    819 derivatives_data, _derivatives_shape = tree_flatten(derivatives)
    821 jvps = [jnp.zeros_like(results_data[res_idx]) for res_idx in range(len(results_data))]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:797, in JAX_QJIT.get_derivative_qjit(self, argnums)
    794 deriv_wrapper.__annotations__ = annotations
    795 deriv_wrapper.__signature__ = signature.replace(parameters=updated_params)
--> 797 self.derivative_functions[argnum_key] = QJIT(
    798     deriv_wrapper, self.qjit_function.compile_options
    799 )
    800 return self.derivative_functions[argnum_key]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:491, in QJIT.__init__(self, fn, compile_options)
    489 # Static arguments require values, so we cannot AOT compile.
    490 if self.user_sig is not None and not self.compile_options.static_argnums:
--> 491     self.aot_compile()

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:525, in QJIT.aot_compile(self)
    520 if self.compile_options.target in ("jaxpr", "mlir", "binary"):
    521     # Capture with the patched conversion rules
    522     with Patcher(
    523         (ag_primitives, "module_allowlist", self.patched_module_allowlist),
    524     ):
--> 525         self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
    526             self.user_sig or ()
    527         )
    529 if self.compile_options.target in ("mlir", "binary"):
    530     self.mlir_module, self.mlir = self.generate_ir()

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
    140 @functools.wraps(fn)
    141 def wrapper(*args, **kwargs):
    142     if not InstrumentSession.active:
--> 143         return fn(*args, **kwargs)
    145     with ResultReporter(stage_name, has_finegrained) as reporter:
    146         fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:628, in QJIT.capture(self, args)
    622 full_sig = merge_static_args(dynamic_sig, args, static_argnums)
    624 with Patcher(
    625     (qml.QNode, "__call__", QFunc.__call__),
    626 ):
    627     # TODO: improve PyTree handling
--> 628     jaxpr, out_type, treedef = trace_to_jaxpr(
    629         self.user_function, static_argnums, abstracted_axes, full_sig, {}
    630     )
    632 return jaxpr, out_type, treedef, dynamic_sig

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jax_tracer.py:531, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs)
    526     with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
    527         make_jaxpr_kwargs = {
    528             "static_argnums": static_argnums,
    529             "abstracted_axes": abstracted_axes,
    530         }
--> 531         jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
    533 return jaxpr, out_type, out_treedef

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jax_extras/tracing.py:539, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    537     f, out_tree_promise = flatten_fun(f, in_tree)
    538     f = annotate(f, in_type)
--> 539     jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
    540 closed_jaxpr = ClosedJaxpr(jaxpr, consts)
    541 return closed_jaxpr, out_type, out_tree_promise()

    [... skipping hidden 4 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jit.py:791, in JAX_QJIT.get_derivative_qjit.<locals>.deriv_wrapper(*args, **kwargs)
    790 def deriv_wrapper(*args, **kwargs):
--> 791     return catalyst.jacobian(self.qjit_function, argnum=argnums)(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/api_extensions/differentiation.py:607, in Grad.__call__(self, *args, **kwargs)
    597 args_data, in_tree = tree_flatten(args)
    598 grad_params = _check_grad_params(
    599     self.grad_params.method,
    600     self.grad_params.scalar_out,
   (...)
    605     self.grad_params.with_value,
    606 )
--> 607 jaxpr, out_tree = _make_jaxpr_check_differentiable(fn, grad_params, *args)
    608 if self.grad_params.with_value:  # use value_and_grad
    609     # It always returns list as required by catalyst control-flows
    610     results = value_and_grad_p.bind(
    611         *args_data, jaxpr=jaxpr, fn=fn, grad_params=grad_params
    612     )

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/api_extensions/differentiation.py:739, in _make_jaxpr_check_differentiable(f, grad_params, *args)
    737 method = grad_params.method
    738 with mark_gradient_tracing(method):
--> 739     jaxpr, shape = jax.make_jaxpr(f, return_shape=True)(*args)
    740 _, out_tree = tree_flatten(shape)
    741 assert len(jaxpr.eqns) == 1, "Expected jaxpr consisting of a single function call."

    [... skipping hidden 6 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/qfunc.py:150, in QFunc.__call__(self, *args, **kwargs)
    148 flattened_fun, _, _, out_tree_promise = deduce_avals(_eval_quantum, args, {})
    149 args_flat = tree_flatten(args)[0]
--> 150 res_flat = func_p.bind(flattened_fun, *args_flat, fn=self)
    151 return tree_unflatten(out_tree_promise(), res_flat)[0]

    [... skipping hidden 4 frame]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/qfunc.py:139, in QFunc.__call__.<locals>._eval_quantum(*args)
    138 def _eval_quantum(*args):
--> 139     closed_jaxpr, out_type, out_tree = trace_quantum_function(
    140         self.func, qjit_device, args, kwargs, self
    141     )
    142     args_expanded = get_implicit_and_explicit_flat_args(None, *args)
    143     res_expanded = eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args_expanded)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jax_tracer.py:1126, in trace_quantum_function(f, device, args, kwargs, qnode)
   1120     qnode_program = qnode.transform_program if qnode else TransformProgram()
   1122     device_modify_measurements = "measurements_from_counts" in [
   1123         t.transform.__name__ for t in device_program
   1124     ]
-> 1126     tapes, post_processing = apply_transform(
   1127         qnode_program,
   1128         device_program,
   1129         device_modify_measurements,
   1130         quantum_tape,
   1131         return_values_flat,
   1132     )
   1134 # (2) - Quantum tracing
   1135 transformed_results = []

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/jax_tracer.py:939, in apply_transform(qnode_program, device_program, device_modify_measurements, tape, flat_results)
    936     # Apply the identity transform in order to keep generalization
    937     total_program = device_program
--> 939 tapes, post_processing = total_program([tape])
    940 if not is_valid_for_batch and len(tapes) > 1:
    941     msg = "Multiple tapes are generated, but each run might produce different results."

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/pennylane/transforms/core/transform_program.py:515, in TransformProgram.__call__(self, tapes)
    513 if self._argnums is not None and self._argnums[i] is not None:
    514     tape.trainable_params = self._argnums[i][j]
--> 515 new_tapes, fn = transform(tape, *targs, **tkwargs)
    516 execution_tapes.extend(new_tapes)
    518 fns.append(fn)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/device/verification.py:225, in verify_operations(tape, grad_method, qjit_device)
    221             _paramshift_op_checker(op)
    223     return (in_inverse, in_control)
--> 225 _verify_nested(tape, (False, False), _op_checker)
    227 return (tape,), lambda x: x[0]

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/device/verification.py:51, in _verify_nested(tape, state, op_checker_fn)
     49 ctx = EvaluationContext.get_main_tracing_context()
     50 for op in tape.operations:
---> 51     state = op_checker_fn(op, state)
     52     if has_nested_tapes(op):
     53         for region in nested_quantum_regions(op):

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/device/verification.py:219, in verify_operations.<locals>._op_checker(op, state)
    217 _mcm_op_checker(op)
    218 if grad_method == "adjoint":
--> 219     _adj_diff_op_checker(op)
    220 elif grad_method == "parameter-shift":
    221     _paramshift_op_checker(op)

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/catalyst/device/verification.py:143, in verify_operations.<locals>._adj_diff_op_checker(op)
    139     op_name = op.name
    140 if not qjit_device.qjit_capabilities.native_ops.get(
    141     op_name, EMPTY_PROPERTIES
    142 ).differentiable:
--> 143     raise DifferentiableCompileError(
    144         f"{op.name} is non-differentiable on '{qjit_device.original_device.name}' device"
    145     )

DifferentiableCompileError: Rot is non-differentiable on 'lightning.qubit' device

Should this be supported then? 🤔

Additional information

isaacdevlugt commented 1 month ago

This works:

import functools

import catalyst  # needed for catalyst.grad below; missing in the original snippet
import pennylane as qml
from jax import numpy as jnp

# Workaround: Catalyst rejects differentiating Rot on lightning.qubit
# ("Rot is non-differentiable on 'lightning.qubit' device"), so the QNode is
# wrapped in PennyLane's decompose transform. As the drawn circuit shows, this
# expands Rot into RZ, RY, RZ before Catalyst verifies the operations.
n_qubits = 1

dev = qml.device("lightning.qubit", wires=n_qubits)


@qml.qjit
@functools.partial(
    qml.devices.preprocess.decompose,
    stopping_condition=lambda obj: obj in dev.operations,
    # NOTE(review): max_expansion=1 limits decomposition to a single level.
    max_expansion=1,
)
@qml.qnode(dev)
def circuit(params):
    """Apply Rot(params[0], params[1], params[2]) on wire 0 and return <Z> on wire 0."""
    qml.Rot(params[0], params[1], params[2], wires=0)
    return qml.expval(qml.PauliZ(0))


params = jnp.array([0.1, 0.2, 0.3])
print(qml.draw(circuit)(params))
catalyst.grad(circuit)(params)
0: ──RZ(0.10)──RY(0.20)──RZ(0.30)─┤  <Z>
Array([ 0.00000000e+00, -1.98669331e-01, -5.55111512e-17], dtype=float64)

Without the stopping condition, we get `DifferentiableCompileError: Rot is non-differentiable on 'lightning.qubit' device`.

isaacdevlugt commented 1 month ago

This is not actually a bug: Lightning doesn't support differentiating `qml.Rot` directly. To differentiate such a circuit, force `Rot` to decompose into individual rotations (RZ, RY, RZ) using the `decompose` transform with a `stopping_condition`, as shown above.