patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Functions without type hints and import hook #197

Closed: nimashoghi closed this issue 1 month ago.

nimashoghi commented 6 months ago

Hi!

First of all, thanks for the awesome library. This library has made my code much more understandable, and the runtime type-checking with beartype has been immensely useful.

I'm currently working on an existing (PyTorch) codebase that did not previously use jaxtyping type hints, and I'm gradually adding them to the areas I work on. As a result, in a handful of places I'm not using function argument/return type hints but am instead using `isinstance` checks, e.g.:


```python
import torch
from jaxtyping import Float


def my_func(x):
    x = ...  # some operation I'm not touching
    # my code
    assert isinstance(x, Float[torch.Tensor, "bsz channels"])
    x = do_some_other_stuff(x)
    # ... and then the rest of the code for `my_func`
```

In these cases, because of the way the import hook is currently set up, I'm running into some very strange and unexpected behavior. Specifically, the axis bindings made by these `isinstance` checks seem to be ignored: they never get registered in the `memo_stack`.
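To make the symptom concrete, here is a deliberately simplified toy model of a per-call memo. It is not jaxtyping's real implementation: `memo_stack`, `shape_matches` and `decorated` are hypothetical stand-ins, and shapes are plain tuples rather than tensors. It only illustrates the idea that axis names bind consistently when an enclosing decorator has pushed a memo, and are forgotten after each check otherwise.

```python
# Toy model with hypothetical names: NOT jaxtyping's real internals.
from typing import Dict, List, Tuple

memo_stack: List[Dict[str, int]] = []  # one memo per decorated call


def shape_matches(shape: Tuple[int, ...], dims: str) -> bool:
    """Check `shape` against space-separated axis names such as "bsz channels"."""
    names = dims.split()
    if len(names) != len(shape):
        return False
    # Bind into the innermost memo if one exists; otherwise use a throwaway
    # dict, so the bindings are forgotten as soon as this check returns.
    memo = memo_stack[-1] if memo_stack else {}
    for name, size in zip(names, shape):
        if memo.setdefault(name, size) != size:
            return False
    return True


def decorated(fn):
    """Stand-in for a decorator added by an import hook: one fresh memo per call."""
    def wrapper(*args, **kwargs):
        memo_stack.append({})
        try:
            return fn(*args, **kwargs)
        finally:
            memo_stack.pop()
    return wrapper


def undecorated_func():
    a = shape_matches((4, 3), "bsz channels")
    b = shape_matches((4, 5), "bsz channels")  # inconsistent `channels`
    return a, b


@decorated
def decorated_func():
    a = shape_matches((4, 3), "bsz channels")
    b = shape_matches((4, 5), "bsz channels")  # inconsistent `channels`
    return a, b


print(undecorated_func())  # (True, True): the second check cannot see the first binding
print(decorated_func())    # (True, False): bindings are shared within the call
```

In the toy, the undecorated function accepts `channels` = 3 and then `channels` = 5 without complaint, which mirrors the reported symptom of bindings being dropped.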

This seems to be because `my_func` has no type annotations on its arguments or return value, so the import hook skips it entirely.
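The annotation-only filter can be reproduced with nothing but the standard-library `ast` module. This is a minimal sketch: the sample `SOURCE` and the helper name `selected_by_annotation_rule` are hypothetical, and the condition is copied from the `visit_FunctionDef` snippet quoted in this issue, so it may not match jaxtyping's current source.

```python
import ast

# Hypothetical sample module: one annotated function, one that only uses isinstance.
SOURCE = '''
def annotated(x: int) -> int:
    return x

def my_func(x):
    assert isinstance(x, tuple)
    return x
'''


def selected_by_annotation_rule(node: ast.FunctionDef) -> bool:
    """Mirror of the annotation-only condition quoted in this issue."""
    has_annotated_args = any(arg for arg in node.args.args if arg.annotation)
    has_annotated_return = bool(node.returns)
    return has_annotated_args or has_annotated_return


tree = ast.parse(SOURCE)
selected = {
    node.name: selected_by_annotation_rule(node)
    for node in ast.walk(tree)
    if isinstance(node, ast.FunctionDef)
}
print(selected)  # {'annotated': True, 'my_func': False}
```

`my_func` has no annotations, so the condition is false even though its body contains an `isinstance` check.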

For now, I've patched jaxtyping's import hook code (`_import_hook.py`) to also register every function that contains an `isinstance` call:

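```python
import ast

def _has_isinstance(func_def):
    # True if the function (including any nested defs) calls the bare name `isinstance`.
    for node in ast.walk(func_def):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "isinstance":
            return True
    return False
```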

I then check for this in `JaxtypingTransformer`, changing the following lines of `visit_FunctionDef` to:

```python
def visit_FunctionDef(self, node: ast.FunctionDef):
    has_annotated_args = any(arg for arg in node.args.args if arg.annotation)
    has_annotated_return = bool(node.returns)
    has_isinstance = _has_isinstance(node)  # new check
    if has_annotated_args or has_annotated_return or has_isinstance:
        ...  # existing decoration logic, unchanged
```

This is a hacky fix, but it works in my case. I'd love to hear your thoughts on how best to fix this (and whether a similar fix is warranted for now).

Thanks!

patrick-kidger commented 6 months ago

Ah, good catch!

Maybe we should just remove the `has_annotated_args or has_annotated_return` check entirely? I don't think it's important; it was just a minor efficiency optimization.

I'd be happy to take a PR on this.
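Removing the filter means every function definition in a hooked module gets decorated, whether or not it has annotations or `isinstance` checks. The sketch below shows that idea with `ast.NodeTransformer`. It assumes a hypothetical decorator name `typecheck_here`, is not jaxtyping's actual transformer, and only handles plain `def` (not `async def`). `ast.unparse` requires Python 3.9 or later.

```python
import ast

# Hypothetical sample module.
SOURCE = '''
def my_func(x):
    assert isinstance(x, tuple)
    return x

def plain(x):
    return x
'''


class DecorateEverything(ast.NodeTransformer):
    """Add a decorator to every function, with no annotation/isinstance filter.

    `typecheck_here` is a hypothetical decorator name used only for illustration.
    """

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # also decorate any nested functions
        # Index 0 makes it the outermost decorator (a choice made for this sketch).
        node.decorator_list.insert(0, ast.Name(id="typecheck_here", ctx=ast.Load()))
        return node


tree = DecorateEverything().visit(ast.parse(SOURCE))
ast.fix_missing_locations(tree)  # new nodes need line numbers before compile()
print(ast.unparse(tree))
```

The trade-off is one extra wrapper per function call, including for functions that never check anything, which is presumably the "minor efficiency thing" the filter was saving.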

patrick-kidger commented 1 month ago

Closing as accomplished in #205, which corresponds to jaxtyping version 0.2.29.