patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Bug with default argument binding #206

Closed jaraujo98 closed 1 month ago

jaraujo98 commented 1 month ago

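jaxtyping does not pick up default argument values when those arguments are not passed explicitly, so a return annotation that refers to them cannot be resolved. See the following minimal reproducible example (MRE):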

```python
import numpy as np
from jaxtyping import Float, jaxtyped
from typeguard import typechecked as typechecker


class DefaultFactory:
    def __init__(self, n=1):
        self.n = n


@jaxtyped(typechecker=typechecker)
def function_with_default(
    positional_argument, default_argument: DefaultFactory = DefaultFactory(5)
) -> Float[np.ndarray, "positional_argument {default_argument.n}"]:
    return np.random.rand(positional_argument, default_argument.n)


# `default_argument` is omitted, so its default DefaultFactory(5) should be used.
function_with_default(3)

# Errors with:
# jaxtyping.AnnotationError: Cannot process symbolic axis 'default_argument.n' as some axis names have not been processed. In practice you should usually only use symbolic axes in annotations for return types, referring only to axes annotated for arguments.
```

After digging a bit through the jaxtyping code, I found that the issue comes from the call `bound = param_signature.bind(*args, **kwargs)` in `jaxtyping._decorator.py:411`. `Signature.bind` only records the arguments that were actually passed, so any argument left at its default value is missing from `bound`. Manually patching `bound` with the default values in a debugger fixes the issue.
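The standard-library behaviour behind this can be seen without jaxtyping at all. The snippet below reuses the MRE's signature (annotations dropped, since only binding matters here) and shows that `inspect.Signature.bind` leaves out omitted defaults, while `BoundArguments.apply_defaults()` fills them in.

```python
import inspect


class DefaultFactory:
    def __init__(self, n=1):
        self.n = n


def function_with_default(positional_argument, default_argument=DefaultFactory(5)):
    ...


signature = inspect.signature(function_with_default)

# bind() only records what the caller actually passed.
bound = signature.bind(3)
print(sorted(bound.arguments))  # ['positional_argument']

# apply_defaults() adds every omitted parameter with its default value.
bound.apply_defaults()
print(sorted(bound.arguments))  # ['default_argument', 'positional_argument']
print(bound.arguments["default_argument"].n)  # 5
```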

patrick-kidger commented 1 month ago

Ah! Perhaps we need a bound.apply_defaults() call here. I'd be happy to take a PR on this.
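A minimal sketch of where such a call would sit, assuming a toy decorator rather than jaxtyping's real one: `check_return_dims` is hypothetical, resolves axis templates with `str.format` (which supports attribute access like `{default_argument.n}`), and checks the (rows, cols) size of a nested-list return value in place of a real array shape check.

```python
import functools
import inspect


class DefaultFactory:
    def __init__(self, n=1):
        self.n = n


def check_return_dims(*templates):
    """Hypothetical toy decorator, not jaxtyping's API.

    Each template is a str.format field such as "{positional_argument}" or
    "{default_argument.n}", resolved against the call's bound arguments and
    compared with the (rows, cols) size of a nested-list return value.
    """

    def decorator(fn):
        signature = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            # The suggested fix: without this, parameters left at their
            # defaults are absent from bound.arguments, and formatting
            # "{default_argument.n}" raises KeyError.
            bound.apply_defaults()
            expected = tuple(int(t.format(**bound.arguments)) for t in templates)
            result = fn(*bound.args, **bound.kwargs)
            actual = (len(result), len(result[0]) if result else 0)
            if actual != expected:
                raise TypeError(f"expected dims {expected}, got {actual}")
            return result

        return wrapper

    return decorator


@check_return_dims("{positional_argument}", "{default_argument.n}")
def function_with_default(positional_argument, default_argument=DefaultFactory(5)):
    # A plain nested list stands in for np.random.rand(rows, cols).
    return [[0.0] * default_argument.n for _ in range(positional_argument)]


rows = function_with_default(3)
print(len(rows), len(rows[0]))  # 3 5
```

Because `apply_defaults()` runs before the templates are formatted, the call that omits `default_argument` resolves `{default_argument.n}` to 5 instead of failing.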

jaraujo98 commented 1 month ago

Oh, wow, I knew it should be something simple, but not this simple ahahah. I'll test it and submit a PR.