patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

_JaxtypingTransformer should decorate classes with beartype instead of methods #92

Open dkamm opened 1 year ago

dkamm commented 1 year ago

This seems like the correct thing to do? See https://beartype.readthedocs.io/en/latest/faq/#the-current-class and the following code sample:

import ast
from jaxtyping.import_hook import _JaxtypingTransformer

transformer = _JaxtypingTransformer(typechecker="beartype.beartype")

source = """
from typing_extensions import Self

class Foo:
    @property
    def bar(self: Self) -> int:
        return 123
"""
tree = compile(source, filename="<string>", mode="exec", flags=ast.PyCF_ONLY_AST)
tree_transformed = transformer.visit(tree)
print(ast.unparse(tree_transformed))
ast.fix_missing_locations(tree_transformed)
code_transformed = compile(tree_transformed, filename="<ast>", mode="exec")
exec(code_transformed)

This errors with:

beartype.roar.BeartypeDecorHintPep673Exception: method __main__.Foo.bar() parameter "self" PEP 673 type hint "typing_extensions.Self" invalid outside @beartype-decorated class. PEP 673 type hints are valid only inside classes decorated by @beartype. If this hint annotates a method decorated by @beartype, instead decorate the class declaring this method by @beartype: e.g.,

    # Instead of decorating methods by @beartype like this...
    class BadClassIsBad(object):
        @beartype
        def awful_method_is_awful(self: Self) -> Self:
            return self

    # ...decorate classes by @beartype instead - like this!
    @beartype
    class GoodClassIsGood(object):
        def wonderful_method_is_wonderful(self: Self) -> Self:
            return self

This has been a message of the Bearhugger Broadcasting Service.

patrick-kidger commented 1 year ago

Right! I've been meaning to fix this for a while. (Thankfully beartype is normally happy with being decorated on methods, so this issue isn't high-priority.)

The problem is that (a) I'd need to think a lot more carefully about how to write the AST transformer, which is quite finickity, and (b) non-beartype runtime type checkers may prefer to have functions decorated, so we need to support both behaviours.

I'd be happy to accept a PR on this.
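
To make points (a) and (b) concrete, here is a minimal sketch of an AST transformer that can attach a decorator either to classes or to individual functions. It is not jaxtyping's actual `_JaxtypingTransformer`: the names `DecoratorInjector`, `decorate_classes`, `decorated_names`, and the stand-in `checker` decorator are all hypothetical, and only the standard-library `ast` module is used.

```python
import ast


class DecoratorInjector(ast.NodeTransformer):
    """Hypothetical stand-in, not jaxtyping's actual _JaxtypingTransformer."""

    def __init__(self, decorator: str, decorate_classes: bool) -> None:
        self.decorator = decorator  # dotted name, e.g. "beartype.beartype"
        self.decorate_classes = decorate_classes
        self._in_class_body = False

    def _decorator(self) -> ast.expr:
        # Parse the dotted name into a fresh Name/Attribute node each time.
        return ast.parse(self.decorator, mode="eval").body

    def visit_ClassDef(self, node):
        outer, self._in_class_body = self._in_class_body, True
        self.generic_visit(node)
        self._in_class_body = outer
        if self.decorate_classes:
            node.decorator_list.append(self._decorator())
        return node

    def visit_FunctionDef(self, node):
        is_method = self._in_class_body
        # Functions nested inside this one are not methods of the class.
        outer, self._in_class_body = self._in_class_body, False
        self.generic_visit(node)
        self._in_class_body = outer
        # In class mode, methods are left to the class-level decorator.
        if not (self.decorate_classes and is_method):
            # Appended last, so it sits innermost (e.g. below @property).
            node.decorator_list.append(self._decorator())
        return node

    visit_AsyncFunctionDef = visit_FunctionDef


SOURCE = """
class Foo:
    @property
    def bar(self) -> int:
        return 123

def baz() -> int:
    return 1
"""


def decorated_names(source: str, decorate_classes: bool) -> list:
    """Run the transformed source and record what the decorator was applied to."""
    seen = []

    def checker(obj):  # stand-in for a real runtime type checker
        seen.append(obj.__name__)
        return obj

    tree = DecoratorInjector("checker", decorate_classes).visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    exec(compile(tree, "<ast>", "exec"), {"checker": checker})
    return seen


print(decorated_names(SOURCE, decorate_classes=True))   # ['Foo', 'baz']
print(decorated_names(SOURCE, decorate_classes=False))  # ['bar', 'baz']
```

The single `decorate_classes` flag is one possible way to support both behaviours. A real implementation would also need to handle whatever else the existing transformer does to each function, which this sketch ignores.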

dkamm commented 1 year ago

I can take a look! I was thinking the same about part (b).

mjo22 commented 6 months ago

Hello! I was wondering if there is an update on this PR. I make heavy use of Self in my library and would like to use the import hook with beartype, so I'm trying to decide on the best course of action. Much appreciated in advance 😊

patrick-kidger commented 6 months ago

So the current state is essentially this comment. Which is to say, right now this still doesn't work, but we do know how we want to fix it: our desired end-goal is for the jaxtyping import hook and the beartype import hook to be usable together.

Rather than having jaxtyping in the driver's seat (and having to apply the @beartype.beartype decorator correctly, as discussed earlier in this thread), this approach would allow each library to just do its own thing independently.

However, at the moment, when the beartype hook is used then the jaxtyping hook seems to be entirely ignored. :(

If you or anyone else would be interested in taking this one, then I'm sure either @leycec or I would be happy to receive a PR so that the two hooks play nicely with each other!

(For anyone else reading this issue, the best practice for now is to do jaxtyping.install_import_hook(typechecker=beartype.beartype), and simply avoid using Self.)
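
As a rough illustration of "avoid using Self", one option is to leave `self` unannotated and name the class in the return annotation as a string forward reference. This is only a sketch: the classes `Foo` and `Bar` and the method `with_value` are hypothetical, and how well this fits a given codebase or type checker is an assumption.

```python
class Foo:
    def __init__(self, value: int = 0) -> None:
        self.value = value

    # With PEP 673 this would read:
    #     def with_value(self: Self, value: int) -> Self:
    # Here `self` is left unannotated and the return type names the class.
    def with_value(self, value: int) -> "Foo":
        # type(self) keeps subclasses returning instances of themselves.
        return type(self)(value)


class Bar(Foo):
    pass


result = Bar().with_value(3)
print(type(result).__name__, result.value)  # Bar 3
```

The trade-off is that the annotation now promises only a `Foo`, so static checkers see a less precise return type on subclasses than they would with `Self`.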

mjo22 commented 6 months ago

This is very helpful to know! Thank you for the update, this sounds like a good solution.