`jaxtyping` does not see default argument values when they are not passed explicitly, so symbolic axes that refer to those arguments cannot be resolved. See the following MRE:
```python
import numpy as np
from jaxtyping import Float, jaxtyped
from typeguard import typechecked as typechecker


class DefaultFactory:
    def __init__(self, n=1):
        self.n = n


@jaxtyped(typechecker=typechecker)
def function_with_default(
    positional_argument, default_argument: DefaultFactory = DefaultFactory(5)
) -> Float[np.ndarray, "positional_argument {default_argument.n}"]:
    return np.random.rand(positional_argument, default_argument.n)


function_with_default(3)
"""
Errors with
jaxtyping.AnnotationError: Cannot process symbolic axis 'default_argument.n' as some axis names have not been processed. In practice you should usually only use symbolic axes in annotations for return types, referring only to axes annotated for arguments.
"""
```
After digging through the `jaxtyping` code, I found that the issue comes from the call `bound = param_signature.bind(*args, **kwargs)` in `jaxtyping._decorator.py:411`, which does not fill in default arguments. Manually patching `bound` with the default values in a debugger fixes the issue.
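For reference, this matches the standard-library behaviour: `inspect.Signature.bind` only records the arguments that were actually passed, and `BoundArguments.apply_defaults()` is what fills in the rest. My guess (not tested against the `jaxtyping` source) is that calling `bound.apply_defaults()` right after the `bind` call would be enough. The sketch below uses only `inspect` and a NumPy-free copy of the MRE function to show the difference. `explicit_defaults` is a hypothetical user-side workaround, not part of `jaxtyping`. In principle it could be stacked above `@jaxtyped(...)` so that defaults reach the checker as explicit arguments, assuming `jaxtyped` preserves the wrapped signature via `functools.wraps`.

```python
import functools
import inspect


class DefaultFactory:
    def __init__(self, n=1):
        self.n = n


# NumPy-free stand-in for the MRE function, returning the resolved sizes.
def function_with_default(positional_argument, default_argument=DefaultFactory(5)):
    return (positional_argument, default_argument.n)


sig = inspect.signature(function_with_default)

# Signature.bind only records the arguments that were actually passed...
bound = sig.bind(3)
print(list(bound.arguments))  # ['positional_argument']

# ...whereas apply_defaults() adds the missing parameters with their defaults.
bound.apply_defaults()
print(list(bound.arguments))  # ['positional_argument', 'default_argument']


def explicit_defaults(fn):
    """Hypothetical workaround: re-call fn with every default passed explicitly."""
    fn_sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        call_args = fn_sig.bind(*args, **kwargs)
        call_args.apply_defaults()  # fill in any omitted defaults
        return fn(*call_args.args, **call_args.kwargs)

    return wrapper


print(explicit_defaults(function_with_default)(3))  # (3, 5)
```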