nimashoghi closed this issue 1 month ago.
Ah, good catch!
Maybe we should just remove the has_annotated_args or has_annotated_return call? I don't think it's important; it was just a minor efficiency thing.
I'd be happy to take a PR on this.
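For illustration, dropping that check amounts to decorating every function definition unconditionally. This sketch assumes nothing about jaxtyping's internals: DecorateAllFunctions, add_decorators, and the decorator name typechecked are hypothetical, and it uses only the standard library ast module (ast.unparse needs Python 3.9+).

```python
import ast


class DecorateAllFunctions(ast.NodeTransformer):
    """Hypothetical transformer: add a decorator to every function,
    whether or not it has annotated arguments or a return annotation."""

    def __init__(self, decorator_name: str) -> None:
        self.decorator_name = decorator_name

    def _decorate(self, node):
        # Insert at position 0 so the new decorator is the outermost one.
        node.decorator_list.insert(0, ast.Name(id=self.decorator_name, ctx=ast.Load()))
        # Keep walking so nested functions are decorated too.
        self.generic_visit(node)
        return node

    visit_FunctionDef = _decorate
    visit_AsyncFunctionDef = _decorate


def add_decorators(source: str, decorator_name: str = "typechecked") -> str:
    """Parse source, decorate every function, and return the new source."""
    tree = DecorateAllFunctions(decorator_name).visit(ast.parse(source))
    ast.fix_missing_locations(tree)  # needed if the tree is later compiled
    return ast.unparse(tree)


src = "def my_func(x):\n    assert isinstance(x, int)\n    return x\n"
print(add_decorators(src))
```

The cost of this approach is that every function gets wrapped, even ones with no jaxtyping annotations or checks, which is presumably the minor efficiency concern mentioned above.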
Closing, as this was accomplished in #205, which corresponds to jaxtyping version 0.2.29.
Hi!
First of all, thanks for the awesome library. This library has made my code much more understandable, and the runtime type-checking with beartype has been immensely useful.
I'm currently working on an existing (PyTorch) codebase that did not previously use jaxtyping type hints, and I'm gradually adding them to the areas I work on. As a result, I have a handful of cases where I don't use function argument/return type hints and instead rely on isinstance checks, e.g.:
In these cases, due to the way the import hook is currently set up, I'm running into some very strange and unexpected behavior. Specifically, axis bindings in these kinds of functions seem to be ignored and are never registered in the memo_stack.
This seems to be because, in the case above, my_func has no jaxtyping type hints on its arguments or return type, and so it is never registered by the import hook.
For now, I've patched jaxtyping's import hook code (_import_hook.py) to also register all functions containing isinstance expressions:
And then I check for this in JaxtypingTransformer, changing the following lines to:
This is a hacky fix, but it works in my case. I'd love to hear your thoughts on how this should be fixed (and whether a similar fix is warranted for now).
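The detection step described above can be sketched with the standard library ast module. This is a hedged illustration, not the actual patch (which is not reproduced here) and not jaxtyping's code: has_annotations, calls_isinstance, and should_decorate are hypothetical helpers, and matching only the bare name isinstance is a simplifying assumption (a call such as builtins.isinstance would be missed).

```python
import ast
from typing import Union

FuncNode = Union[ast.FunctionDef, ast.AsyncFunctionDef]


def has_annotations(node: FuncNode) -> bool:
    """True if any argument or the return value is annotated."""
    args = node.args
    all_args = args.posonlyargs + args.args + args.kwonlyargs
    if args.vararg is not None:
        all_args.append(args.vararg)
    if args.kwarg is not None:
        all_args.append(args.kwarg)
    return node.returns is not None or any(a.annotation is not None for a in all_args)


def calls_isinstance(node: FuncNode) -> bool:
    """True if a call to the bare name `isinstance` appears anywhere
    inside the function node (including nested functions)."""
    for child in ast.walk(node):
        if (
            isinstance(child, ast.Call)
            and isinstance(child.func, ast.Name)
            and child.func.id == "isinstance"
        ):
            return True
    return False


def should_decorate(node: FuncNode) -> bool:
    # Hypothetical rule: keep decorating annotated functions, and also
    # decorate any function that performs an isinstance check.
    return has_annotations(node) or calls_isinstance(node)


src = """
def annotated(x: int) -> int:
    return x

def uses_isinstance(x):
    assert isinstance(x, int)
    return x

def plain(x):
    return x
"""
for fn in ast.parse(src).body:
    print(fn.name, should_decorate(fn))
```

Scanning with ast.walk keeps the check purely syntactic, so it runs at import time without executing any of the module's code.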
Thanks!