Closed: tzunghanjuang closed this issue 3 months ago.
Using `catalyst.qjit`, `catalyst.grad`, and `functools.partial` together triggers the error `AttributeError: 'functools.partial' object has no attribute '__name__'`.
"""Minimal reproducer: ``catalyst.grad`` over a ``functools.partial`` inside ``@qjit``.

Running this script raises
``AttributeError: 'functools.partial' object has no attribute '__name__'``.

According to the accompanying traceback, the failure occurs in three steps:

* ``Grad.__call__`` passes the callable to ``_ensure_differentiable``.
* ``_ensure_differentiable`` wraps it in ``catalyst.jax_tracer.Function``.
* ``Function.__init__`` executes ``self.__name__ = fn.__name__``, an attribute
  the ``functools.partial`` object does not have.

NOTE: the accompanying traceback was captured from a different notebook cell
(a ``@qjit``-decorated ``step`` function calling ``grad(cost)``). It shows the
same ``AttributeError`` raised from ``Function.__init__``.
"""
import functools

from catalyst import grad, qjit


def fn(x, y):
    """Return the product ``x * y``."""
    return x * y


# Fix ``y`` so the callable depends on ``x`` alone. The resulting partial
# object carries no ``__name__``, which is what trips up ``grad`` below.
partial_fn = functools.partial(fn, y=1)


@qjit
def grad_partial_fn(x):
    """Differentiate ``partial_fn`` with respect to ``x`` inside a qjit trace."""
    d_partial_fn = grad(partial_fn)
    return d_partial_fn(x)


grad_partial_fn(0.3)
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) Cell In[1], line 34 31 updates, opt_state = optimizer.update(grad_circuit, opt_state) 33 opt_state = optimizer.init(params) ---> 34 step(opt_state, params) File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs) 54 s_caller = "::L".join( 55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]] 56 ) 57 lgr.debug( 58 f"Calling {f_string} from {s_caller}", 59 **_debug_log_kwargs, 60 ) ---> 61 return func(*args, **kwargs) File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:513, in QJIT.__call__(self, *args, **kwargs) 510 if EvaluationContext.is_tracing(): 511 return self.user_function(*args, **kwargs) --> 513 requires_promotion = self.jit_compile(args) 515 # If we receive tracers as input, dispatch to the JAX integration. 516 if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]): File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs) 54 s_caller = "::L".join( 55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]] 56 ) 57 lgr.debug( 58 f"Calling {f_string} from {s_caller}", 59 **_debug_log_kwargs, 60 ) ---> 61 return func(*args, **kwargs) File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:582, in QJIT.jit_compile(self, args) 578 # Capture with the patched conversion rules 579 with Patcher( 580 (ag_primitives, "module_allowlist", self.patched_module_allowlist), 581 ): --> 582 self.jaxpr, self.out_treedef, self.c_sig = self.capture(args) 584 self.mlir_module, self.mlir = self.generate_ir() 585 self.compiled_function, self.qir = self.compile() File ~/Workspace/catalyst-clone/frontend/catalyst/debug/instruments.py:143, in instrument.<locals>.wrapper(*args, **kwargs) 140 @functools.wraps(fn) 141 def 
wrapper(*args, **kwargs): 142 if not InstrumentSession.active: --> 143 return fn(*args, **kwargs) 145 with ResultReporter(stage_name, has_finegrained) as reporter: 146 fn_results, wall_time, cpu_time = time_function(fn, args, kwargs) File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs) 54 s_caller = "::L".join( 55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]] 56 ) 57 lgr.debug( 58 f"Calling {f_string} from {s_caller}", 59 **_debug_log_kwargs, 60 ) ---> 61 return func(*args, **kwargs) File ~/Workspace/catalyst-clone/frontend/catalyst/jit.py:640, in QJIT.capture(self, args) 634 full_sig = merge_static_args(dynamic_sig, args, static_argnums) 636 with Patcher( 637 (qml.QNode, "__call__", QFunc.__call__), 638 ): 639 # TODO: improve PyTree handling --> 640 jaxpr, treedef = trace_to_jaxpr( 641 self.user_function, static_argnums, abstracted_axes, full_sig, {} 642 ) 644 return jaxpr, treedef, dynamic_sig File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:61, in log_string_debug_func.<locals>.wrapper_entry(*args, **kwargs) 54 s_caller = "::L".join( 55 [str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]] 56 ) 57 lgr.debug( 58 f"Calling {f_string} from {s_caller}", 59 **_debug_log_kwargs, 60 ) ---> 61 return func(*args, **kwargs) File ~/Workspace/catalyst-clone/frontend/catalyst/jax_tracer.py:369, in trace_to_jaxpr(func, static_argnums, abstracted_axes, args, kwargs) 364 with EvaluationContext(EvaluationMode.CLASSICAL_COMPILATION): 365 make_jaxpr_kwargs = { 366 "static_argnums": static_argnums, 367 "abstracted_axes": abstracted_axes, 368 } --> 369 jaxpr, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs) 371 return jaxpr, out_treedef File ~/Workspace/catalyst-clone/frontend/catalyst/jax_extras/tracing.py:410, in make_jaxpr2.<locals>.make_jaxpr_f(*args, **kwargs) 408 f, out_tree_promise = 
flatten_fun(f, in_tree) 409 f = annotate(f, in_type) --> 410 jaxpr, output_type, consts = trace_to_jaxpr_dynamic2(f) 411 closed_jaxpr = DynshapedClosedJaxpr(jaxpr, consts, output_type) 412 return closed_jaxpr, out_tree_promise() File ~/.local/lib/python3.10/site-packages/jax/_src/profiler.py:336, in annotate_function.<locals>.wrapper(*args, **kwargs) 333 @wraps(func) 334 def wrapper(*args, **kwargs): 335 with TraceAnnotation(name, **decorator_kwargs): --> 336 return func(*args, **kwargs) 337 return wrapper File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2324, in trace_to_jaxpr_dynamic2(fun, debug_info) 2322 with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore 2323 main.jaxpr_stack = () # type: ignore -> 2324 jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) 2325 del main, fun 2326 return jaxpr, out_type, consts File ~/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:2339, in trace_to_subjaxpr_dynamic2(fun, main, debug_info) 2337 in_tracers = _input_type_to_tracers(trace.new_arg, in_avals) 2338 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep] -> 2339 ans = fun.call_wrapped(*in_tracers_) 2340 out_tracers = map(trace.full_raise, ans) 2341 jaxpr, out_type, consts = frame.to_jaxpr2(out_tracers) File ~/.local/lib/python3.10/site-packages/jax/_src/linear_util.py:191, in WrappedFun.call_wrapped(self, *args, **kwargs) 188 gen = gen_static_args = out_store = None 190 try: --> 191 ans = self.f(*args, **dict(self.params, **kwargs)) 192 except: 193 # Some transformations yield from inside context managers, so we have to 194 # interrupt them before reraising the exception. Otherwise they will only 195 # get garbage-collected at some later time, running their cleanup tasks 196 # only after this exception is handled, which can corrupt the global 197 # state. 
198 while stack: Cell In[1], line 30, in step(opt_state, theta) 28 @qjit 29 def step(opt_state, theta): ---> 30 val, grad_circuit = cost(theta), grad(cost)(theta) 31 updates, opt_state = optimizer.update(grad_circuit, opt_state) File ~/Workspace/catalyst-clone/frontend/catalyst/api_extensions/differentiation.py:596, in Grad.__call__(self, *args, **kwargs) 591 if EvaluationContext.is_tracing(): 592 assert ( 593 not self.grad_params.with_value 594 ), "Tracing of value_and_grad is not implemented yet" --> 596 fn = _ensure_differentiable(self.fn) 598 args_data, in_tree = tree_flatten(args) 599 grad_params = _check_grad_params( 600 self.grad_params.method, 601 self.grad_params.scalar_out, (...) 606 self.grad_params.with_value, 607 ) File ~/Workspace/catalyst-clone/frontend/catalyst/api_extensions/differentiation.py:710, in _ensure_differentiable(f) 708 return f 709 elif isinstance(f, Callable): # Keep at the bottom --> 710 return Function(f) 712 raise DifferentiableCompileError(f"Non-differentiable object passed: {type(f)}") File ~/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py:65, in log_string_debug_func.<locals>.wrapper_exit(*args, **kwargs) 63 @wraps(func) 64 def wrapper_exit(*args, **kwargs): ---> 65 output = func(*args, **kwargs) 66 if lgr.isEnabledFor(log_level): # pragma: no cover 67 f_string = _get_bound_signature(*args, **kwargs) File ~/Workspace/catalyst-clone/frontend/catalyst/jax_tracer.py:108, in Function.__init__(self, fn) 105 @debug_logger_init 106 def __init__(self, fn): 107 self.fn = fn --> 108 self.__name__ = fn.__name__ AttributeError: 'functools.partial' object has no attribute '__name__'
Issue description
Using `catalyst.qjit`, `catalyst.grad`, and `functools.partial` together triggers the error `AttributeError: 'functools.partial' object has no attribute '__name__'`.

Tracebacks