PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

`qjit(static_argnums=...)` fails when the marked static argument has a default value #1163

Open paul0403 opened 1 week ago

paul0403 commented 1 week ago
from catalyst import qjit

@qjit(static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000

res = f(20)
print(res)
Traceback (most recent call last):
  File "/home/paul.wang/catalyst_new/catalyst/static_argname.py", line 24, in <module>
    res = f(20)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 513, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 584, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 631, in capture
    verify_static_argnums(args, self.compile_options.static_argnums)
  File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/tracing/type_signatures.py", line 113, in verify_static_argnums
    raise CompileError(msg)
catalyst.utils.exceptions.CompileError: argnum 1 is beyond the valid range of [0, 1).
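The error message suggests that each static index is compared against the number of positional arguments actually passed at the call, so a parameter that is filled by its default is invisible to the check. The following is a minimal pure-Python sketch of such a check, assuming that behaviour; `check_static_argnums_at_call` is a hypothetical helper inferred from the message, not Catalyst's actual `verify_static_argnums`.

```python
def check_static_argnums_at_call(args, static_argnums):
    """Hypothetical check: validate static indices against call-site arguments only."""
    for argnum in static_argnums:
        # Only the arguments passed at the call are counted, so defaults are not seen.
        if not 0 <= argnum < len(args):
            raise ValueError(
                f"argnum {argnum} is beyond the valid range of [0, {len(args)})."
            )


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


# f(20) passes a single positional argument, so index 1 is rejected.
try:
    check_static_argnums_at_call((20,), [1])
except ValueError as err:
    print(err)  # argnum 1 is beyond the valid range of [0, 1).

# Passing x explicitly makes index 1 valid under this kind of check.
check_static_argnums_at_call((20, 9), [1])
print(f(20, 9))  # 29
```

Under this assumption, calling `f(20, 9)` explicitly would sidestep the error, at the cost of repeating the default at every call site.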

Note that the same function works under JAX:

from functools import partial

import jax

@partial(jax.jit, static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000

res = f(20)
print(res)
29
paul0403 commented 1 week ago

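I think this is because `static_argnums` is validated against the arguments supplied at the call site, not against the parameters in the function definition. Here `f(20)` passes only one positional argument, so index 1 is reported as out of range even though `x` is a valid parameter with a default value.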
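One way to make the index refer to the definition rather than the call is to bind the call-site arguments to the function's signature and fill in defaults before looking up the static values. The sketch below uses the standard-library `inspect` module; `resolve_static_args` is a hypothetical helper, not part of Catalyst's API, and it assumes every static index refers to an ordinary positional parameter (no `*args`).

```python
import inspect


def resolve_static_args(fn, static_argnums, *args, **kwargs):
    """Hypothetical helper: map static argnums to values using fn's signature."""
    sig = inspect.signature(fn)
    params = list(sig.parameters.values())
    for argnum in static_argnums:
        # Validate against the number of parameters in the definition.
        if not 0 <= argnum < len(params):
            raise ValueError(
                f"argnum {argnum} is beyond the valid range of [0, {len(params)})."
            )
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()  # fills in x=9 when the caller omits it
    return {params[i].name: bound.arguments[params[i].name] for i in static_argnums}


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


print(resolve_static_args(f, [1], 20))        # {'x': 9}
print(resolve_static_args(f, [1], 20, 3))     # {'x': 3}
print(resolve_static_args(f, [1], 20, x=11))  # {'x': 11}
```

Binding through the signature also handles the case where the static argument is passed by keyword, which a purely positional check would miss.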