@qjit(static_argnums=[1])
def f(y, x=9):
    """Return ``x + y`` if ``x < 10``, otherwise the constant ``42000``.

    ``x`` (positional index 1) is declared static via ``static_argnums=[1]``.

    NOTE: with Catalyst, calling ``f(20)`` -- i.e. leaving ``x`` at its
    default instead of passing it positionally -- raises
    ``CompileError: argnum 1 is beyond the valid range of [0, 1)``.
    """
    if x < 10:
        return x + y
    return 42000
# Reproducer call: supply only ``y`` so that ``x`` (static argnum 1) is left
# at its default. Per the traceback that follows, Catalyst rejects this call
# with a CompileError.
print(res := f(20))
Traceback (most recent call last):
File "/home/paul.wang/catalyst_new/catalyst/static_argname.py", line 24, in <module>
res = f(20)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 513, in __call__
requires_promotion = self.jit_compile(args, **kwargs)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 584, in jit_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
return fn(*args, **kwargs)
File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/jit.py", line 631, in capture
verify_static_argnums(args, self.compile_options.static_argnums)
File "/home/paul.wang/catalyst_new/catalyst/frontend/catalyst/tracing/type_signatures.py", line 113, in verify_static_argnums
raise CompileError(msg)
catalyst.utils.exceptions.CompileError: argnum 1 is beyond the valid range of [0, 1).
Note that the equivalent code works with `jax.jit`:
@partial(jax.jit, static_argnums=[1])
def f(y, x=9):
    """Return ``x + y`` if ``x < 10``, otherwise the constant ``42000``.

    ``x`` (positional index 1) is declared static via ``static_argnums=[1]``.
    Unlike the Catalyst ``qjit`` version of this function, ``jax.jit``
    accepts a call such as ``f(20)`` that leaves ``x`` at its default.
    """
    if x < 10:
        return x + y
    return 42000
# Reproducer call: supply only ``y``; ``x`` (static argnum 1) is left at its
# default, which ``jax.jit`` accepts.
print(res := f(20))
With `jax.jit`, the call `f(20)` simply uses the default `x=9` and prints `29`.