PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0
130 stars 29 forks source link

[BUG] ``qjit`` fails to access ``__name__`` from ``functools.partial`` #815

Closed tzunghanjuang closed 3 months ago

tzunghanjuang commented 3 months ago

Issue description

Differentiating a `functools.partial` object with `catalyst.grad` inside a `catalyst.qjit`-compiled function raises `AttributeError: 'functools.partial' object has no attribute '__name__'`.

from catalyst import qjit, grad
import functools

def fn(x, y):
    """Return the product of ``x`` and ``y``."""
    product = x * y
    return product


# Bind ``y`` ahead of time. The result is a ``functools.partial`` object,
# which (unlike a plain function) has no ``__name__`` attribute.
partial_fn = functools.partial(fn, y=1)

@qjit
def grad_partial_fn(x):
    """Differentiate ``partial_fn`` at ``x`` inside a ``qjit``-compiled function."""
    d_partial_fn = grad(partial_fn)
    return d_partial_fn(x)


# On affected Catalyst versions this call raises
# AttributeError: 'functools.partial' object has no attribute '__name__'.
grad_partial_fn(0.3)

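Traceback

Note: this traceback was captured from a larger notebook script, not from the minimal example above. In that script, an optimizer `step` function calls `grad(cost)`, where `cost` is a `functools.partial`. The failure point is the same in both: `Function.__init__` in `catalyst/jax_tracer.py` reads `fn.__name__`, and `functools.partial` objects do not have that attribute.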

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[1], line 34
     31     updates, opt_state = optimizer.update(grad_circuit, opt_state)
     33 opt_state = optimizer.init(params)
---> 34 step(opt_state, params)

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:513, in QJIT.__call__(self, *args, **kwargs)
    510 if EvaluationContext.is_tracing():
    511     return self.user_function(*args, **kwargs)
--> 513 requires_promotion = self.jit_compile(args)
    515 # If we receive tracers as input, dispatch to the JAX integration.
    516 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]):

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:582, in QJIT.jit_compile(self, args)
    578 # Capture with the patched conversion rules
    579 with Patcher(
    580     (ag_primitives, "module_allowlist", self.patched_module_allowlist),
    581 ):
--> 582     self.jaxpr, self.out_treedef, self.c_sig = self.capture(args)
    584 self.mlir_module, self.mlir = self.generate_ir()
    585 self.compiled_function, self.qir = self.compile()

File ~/Workspace/catalyst-clone/frontend/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs)
    140 @functools.wraps(fn)
    141 def wrapper(*args, **kwargs):
    142     if not InstrumentSession.active:
--> 143         return fn(*args, **kwargs)
    145     with ResultReporter(stage_name, has_finegrained) as reporter:
    146         fn_results, wall_time, cpu_time = time_function(fn, args, kwargs)

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:640, in QJIT.capture(self, args)
    634 full_sig = merge_static_args(dynamic_sig, args, static_argnums)
    636 with Patcher(
    637     (qml.QNode, "__call__", QFunc.__call__),
    638 ):
    639     # TODO: improve PyTree handling
--> 640     jaxpr, treedef = trace_to_jaxpr(
    641         self.user_function, static_argnums, abstracted_axes, full_sig, {}
    642     )
    644 return jaxpr, treedef, dynamic_sig

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs)
     54     s_caller = "::L".join(
     55         [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]]
     56     )
     57     lgr.debug(
     58         f"Calling {f_string} from {s_caller}",
     59         **_debug_log_kwargs,
     60     )
---> 61 return func(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jax_tracer.py:369, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs)
    364     with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION):
    365         make_jaxpr_kwargs = {
    366             "static_argnums": static_argnums,
    367             "abstracted_axes": abstracted_axes,
    368         }
--> 369         jaxpr, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
    371 return jaxpr, out_treedef

File ~/Workspace/catalyst-clone/frontend/catalyst/jax_extras/tracing.py:410, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs)
    408     f, out_tree_promise = flatten_fun(f, in_tree)
    409     f = annotate(f, in_type)
--> 410     jaxpr, output_type, consts = trace_to_jaxpr_dynamic2(f)
    411 closed_jaxpr = DynshapedClosedJaxpr(jaxpr, consts, output_type)
    412 return closed_jaxpr, out_tree_promise()

File ~/.local/lib/python3.10/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs)
    333 @wraps(func)
    334 def wrapper(*args, **kwargs):
    335   with TraceAnnotation(name, **decorator_kwargs):
--> 336     return func(*args, **kwargs)
    337   return wrapper

File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2324, in trace_to_jaxpr_dynamic2(fun, debug_info)
   2322 with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
   2323   main.jaxpr_stack = ()  # type: ignore
-> 2324   jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2325   del main, fun
   2326 return jaxpr, out_type, consts

File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2339, in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2337 in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)
   2338 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2339 ans = fun.call_wrapped(*in_tracers_)
   2340 out_tracers = map(trace.full_raise, ans)
   2341 jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers)

File ~/.local/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs)
    188 gen = gen_static_args = out_store = None
    190 try:
--> 191   ans = self.f(*args, **dict(self.params, **kwargs))
    192 except:
    193   # Some transformations yield from inside context managers, so we have to
    194   # interrupt them before reraising the exception. Otherwise they will only
    195   # get garbage-collected at some later time, running their cleanup tasks
    196   # only after this exception is handled, which can corrupt the global
    197   # state.
    198   while stack:

Cell In[1], line 30, in step(opt_state, theta)
     28 @qjit
     29 def step(opt_state, theta):
---> 30     val, grad_circuit = cost(theta), grad(cost)(theta)
     31     updates, opt_state = optimizer.update(grad_circuit, opt_state)

File ~/Workspace/catalyst-clone/frontend/catalyst/api_extensions/differentiation.py:596, in Grad.__call__(self, *args, **kwargs)
    591 if EvaluationContext.is_tracing():
    592     assert (
    593         not self.grad_params.with_value
    594     ), "Tracing of value_and_grad is not implemented yet"
--> 596     fn = _ensure_differentiable(self.fn)
    598     args_data, in_tree = tree_flatten(args)
    599     grad_params = _check_grad_params(
    600         self.grad_params.method,
    601         self.grad_params.scalar_out,
   (...)
    606         self.grad_params.with_value,
    607     )

File ~/Workspace/catalyst-clone/frontend/catalyst/api_extensions/differentiation.py:710, in _ensure_differentiable(f)
    708     return f
    709 elif isinstance(f, Callable):  # Keep at the bottom
--> 710     return Function(f)
    712 raise DifferentiableCompileError(f"Non-differentiable object passed: {type(f)}")

File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs)
     63 @wraps(func)
     64 def wrapper_exit(*args, **kwargs):
---> 65     output = func(*args, **kwargs)
     66     if lgr.isEnabledFor(log_level):  # pragma: no cover
     67         f_string = _get_bound_signature(*args, **kwargs)

File ~/Workspace/catalyst-clone/frontend/catalyst/jax_tracer.py:108, in Function.__init__(self, fn)
    105 @debug_logger_init
    106 def __init__(self, fn):
    107     self.fn = fn
--> 108     self.__name__ = fn.__name__

AttributeError: 'functools.partial' object has no attribute '__name__'