patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Will runtime type checking go beyond function parameters and return types? #153

Open jeezrick opened 8 months ago

jeezrick commented 8 months ago

Great project; it helps me understand DL code a lot.

I used it like this:

import torch
from beartype import beartype
from jaxtyping import Float32, jaxtyped

Patch_embed = Float32[torch.Tensor, f"B {PATCH_H} {PATCH_W} {PATCH_EMBED_DIM}"]
Mlp_mid = Float32[torch.Tensor, f"B {PATCH_H} {PATCH_W} {MLP_HIDDEN}"]

...

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Patch_embed) -> Patch_embed:
        x: Mlp_mid = self.act(self.lin1(x))
        return self.lin2(x)

But it turns out it doesn't do a runtime type check on the x: Mlp_mid = self.act(self.lin1(x)) line, and this makes me feel insecure. So, my question is: will this feature be added in the future? Or is it in conflict with some design intention?

BTW, I mainly use it when I am trying to understand others' code. But can I include it in production? How much does it slow down training and inference?

patrick-kidger commented 8 months ago

So right now this will work if you use a manual isinstance check:

@jaxtyped(typechecker=beartype)
def forward(self, x: Patch_embed) -> Patch_embed:
    x: Mlp_mid = self.act(self.lin1(x))
    assert isinstance(x, Mlp_mid)
    return self.lin2(x)

but I agree that inserting this automatically would be an awesome feature to have.

I can see two possible ways we might add this:

Tagging @leycec -- I know you've tackled this problem before; do you have any thoughts on the matter?


BTW, I mainly use it when I am trying to understand others' code. But can I include it in production? How much does it slow down training and inference?

If you're using JAX, then it shouldn't slow down runtime at all. There will only be a negligible amount of work at compile time, just to check that the shapes and dtypes of the arrays are what they say they are.

If you're using PyTorch/etc. then the amount of overhead hasn't really been benchmarked, so I'm less sure :)

leycec commented 8 months ago

...heh. The age-old PEP 526-compliant Annotated Variable Assignment Runtime Type-checking Problem, huh? That one just never gets old. So, thanks for pinging me! I love this sort of hacky Python wrangling.

Sadly, you already know everything. This is like when Luke in Empire Strikes Back finally realizes that the little wizened green lizard blob thing is actually the most powerful surviving Jedi in the entire universe. You are that blob thing. Wait. That... no longer sounds complimentary. Let's awkwardly start over. </ahem>

Two Roads: One That Sucks and One That Sucks a Bit Less

As you astutely surmise, two sucky roads lie before you:

Relatedly, you almost certainly want to ~steal everything not nailed down~ be inspired by @beartype's own visit_AnnAssign() implementation. The proliferation of edge cases makes the implementation non-trivial. Because @beartype and the Power of Asperger's, the code is outrageously commented. You can probably streamline that by... a lot.
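For readers unfamiliar with the AST route, here is a minimal sketch of what inserting the check automatically could look like: an ast.NodeTransformer whose visit_AnnAssign() appends an isinstance assertion after each simple `name: hint = value` statement, i.e. the manual check from the comment above, done mechanically. AnnAssignChecker and exec_checked are hypothetical names; this is not beartype's or jaxtyping's transformer, and it deliberately skips most of the edge cases mentioned above.

```python
import ast
import copy


class AnnAssignChecker(ast.NodeTransformer):
    """Hypothetical toy transformer: after every `name: hint = value`
    statement, insert `assert isinstance(name, hint)`."""

    def visit_AnnAssign(self, node: ast.AnnAssign):
        # Only handle the simple `name: hint = value` form. Bare annotations
        # (`x: int`) and attribute/subscript targets (`self.x: int = ...`)
        # are some of the edge cases a real implementation must also cover.
        if node.value is None or not isinstance(node.target, ast.Name):
            return node
        check = ast.Assert(
            test=ast.Call(
                func=ast.Name(id="isinstance", ctx=ast.Load()),
                args=[
                    ast.Name(id=node.target.id, ctx=ast.Load()),
                    # Copy so the annotation node is not shared between statements.
                    copy.deepcopy(node.annotation),
                ],
                keywords=[],
            ),
            msg=None,
        )
        # Returning a list splices both statements in place of the original.
        return [node, check]


def exec_checked(source: str) -> dict:
    """Transform `source`, execute it, and return the resulting namespace."""
    tree = AnnAssignChecker().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    namespace: dict = {}
    exec(compile(tree, "<checked>", "exec"), namespace)
    return namespace
```

For example, running exec_checked on a function containing `x: int = v` and then calling that function with "three" raises AssertionError. Two caveats: assert statements vanish under `python -O`, and hints that isinstance() rejects (such as list[int]) would raise TypeError; jaxtyping hints are fine here, since they are designed to be used with isinstance().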

Let's Make a Deal: A Match Made in GitHub

Tangentially, would you like @beartype to automate your import hooks for you?

@beartype would be delighted to silently piggyback jaxtyping onto its own hunched back. @beartype already provides a rich constellation of import hook primitives (e.g., beartype_this_package(), beartype_packages(), beartyping()). Moreover, anyone who wants @beartype probably also wants jaxtyping. Finally, it's optimally efficient for @beartype to circumvent all of the file I/O and CPU churn associated with the jaxtyping.install_import_hook() import hook by instead just directly applying your currently private (and thus horribly risky) jaxtyping._import_hook.JaxtypingTransformer subclass where @beartype applies its own beartype.claw._ast.clawastmain.BeartypeNodeTransformer subclass.

Specifically, @beartype would automate away everything by:

# In our private "beartype.claw._importlib._clawimpload" submodule:
...
        # Attempt to...
        try:
            # Defer optional dependency imports.
            from jaxtyping._import_hook import JaxtypingTransformer

            # AST transformer decorating typed callables and classes by "jaxtyping".
            #
            # Note that we intentionally pass *NO* third-party "typechecker" to avoid
            # redundantly applying the @beartype.beartype decorator to the same callables
            # and classes twice.
            ast_jaxtyper = JaxtypingTransformer(typechecker=None)

            # Abstract syntax tree (AST) modified by this transformer.
            module_ast = ast_jaxtyper.visit(module_ast)
        # If "jaxtyping" is currently unimportable, silently pretend everything is well.
        except ImportError:
            pass

        # AST transformer decorating typed callables and classes by @beartype.
        ast_beartyper = BeartypeNodeTransformer(
            conf_beartype=self._module_conf_beartype)

        # Abstract syntax tree (AST) modified by this transformer.
        module_ast_beartyped = ast_beartyper.visit(module_ast)

Of course, you don't need to actually depend upon or require @beartype in any way. No changes are needed on your end. Actually, one little change would improve sanity: if you wouldn't mind publicizing JaxtypingTransformer somewhere, @beartype can then just publicly import that and no longer worry about destroying itself.

That's it. Super-easy, honestly. Everyone currently getting @beartype would then get jaxtyping as well for free. Free is good. Let us give the good people what they want, @patrick-kidger. Smile, benevolent cat! Smile! :smile_cat:

Equinox: You Are Now Good to Go

Oh – and @beartype now officially supports Equinox. 2024: dis goin' be gud.

patrick-kidger commented 8 months ago

Thank you @leycec! Okay, looks like we're stuck with these two options. I'll have a think about what implementing these looks like. I'm not wild about either approach; they're both pretty magic... :D The beartype implementation in particular is a very useful reference.


As for having the AST transformers / import hooks work together... hmm, I think I'd need to understand this better still. I think the two approaches we've got right now aren't really directly compatible with each other. We might need some more changes elsewhere before this is doable.

To explain: jaxtyping doesn't really want you to use JaxtypingTransformer(typechecker=None). This is the old way of doing things.

As of fairly recently, we now do some fairly evil things under the hood, corresponding to passing typechecker=beartype.beartype. This terrifying block of code will transform

@jaxtyped(typechecker=beartype)
def foo(x: Ann) -> Ret:
    ... # do stuff

out = foo(bar)

into

@beartype
def foo_args(x: Ann):
    pass

try:
    foo_args(bar)
except Exception as e:
    raise ValueError("A helpful error message for the arguments") from e

out = foo(bar)

@beartype
def foo_ret(x: Ann) -> Ret:
    return out

try:
    foo_ret(bar)
except Exception as e:
    raise ValueError("A helpful error message for the return value") from e

where foo_args and foo_ret are actually functions dynamically generated at decoration time. Fun.

The reason for this is so that we can finally give those nice error messages about shapes and dtypes and what-not when something goes wrong. What this means is that if jaxtyping is being used, it's basically assuming that it's "in charge": it gets to decide how to report error messages, not beartype! (Although you can scroll up a bit in the traceback to see the underlying error message that was caught and attached as __cause__.)

This is actually pretty nice from the jaxtyping point-of-view. We don't have to worry about whether you're using beartype or typeguard or anything else: if the decorator raises an exception, we can use it.
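Here is a runnable toy version of that idea, to show why it is checker-agnostic: whichever decorator is passed in is applied to an argument-only stub, and any exception it raises is re-raised as a ValueError with the original attached as __cause__. checked_by and toy_typechecker are hypothetical names. toy_typechecker stands in for beartype/typeguard and reads the stub's __signature__, whereas real checkers inspect the function itself, which is presumably why jaxtyping generates genuine functions at decoration time. Return-value checking is omitted for brevity.

```python
import functools
import inspect


def toy_typechecker(fn):
    """Stand-in for beartype/typeguard: checks plain-class parameter hints only,
    using the function's (possibly overridden) signature."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = sig.parameters[name].annotation
            if hint is not inspect.Parameter.empty and not isinstance(value, hint):
                raise TypeError(f"{name}={value!r} is not a {hint.__name__}")
        return fn(*args, **kwargs)

    return wrapper


def checked_by(typechecker):
    """Hypothetical sketch of the pattern above (not jaxtyping's code): run the
    supplied typechecker on an argument-only stub, re-raising any failure."""

    def decorator(fn):
        sig = inspect.signature(fn)

        def args_stub(*args, **kwargs):
            pass  # the body never matters; only the checker's wrapper does

        # Present the stub with fn's parameters but no return annotation.
        args_stub.__signature__ = sig.replace(return_annotation=inspect.Signature.empty)
        args_stub.__annotations__ = {
            k: v for k, v in fn.__annotations__.items() if k != "return"
        }
        checked_stub = typechecker(args_stub)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                checked_stub(*args, **kwargs)
            except Exception as e:
                # The checker's own message survives as e.__cause__.
                raise ValueError(
                    f"Type-check error whilst checking the arguments of {fn.__name__}."
                ) from e
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```

Swapping toy_typechecker for a different decorator changes nothing in checked_by, which is the property described above.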

So, what does this mean at this point? Frankly, I'm not 100% sure. :D It's definitely not optimal for jaxtyping to have to reimplement the magic required for checking statement annotations. Perhaps we could factor out all the import hook business into a shared third library? Possibly that leads to fire and explosions.


Equinox: hurrah! That's awesome. What an excellent Christmas gift.

leycec commented 8 months ago

Woooooooah. Indeed, I see terrifying – yet ultimately justifiable – shenanigans that sadden me. In the absence of a standardized plugin system for runtime-static type-checkers generically supported by both @beartype and typeguard, though, what else you gonna do? Right? So you force the issue by just doing everything yourself. Still, I feel a gnawing fear. That's some pretty gnarly AST munging just to circumvent the lack of a sane plugin API in other people's code.

@beartype and typeguard alike are clearly failing jaxtyping here. But... it's not just jaxtyping we're failing. It's any third-party package publishing custom type hint factories like jaxtyping.Float[...], really.

My only wish for 2024 is to stop failing jaxtyping. That, and for our Maine Coons to finally cough up their pernicious hairballs. Just give it up already, cats! :cat2:

Oh. Oh. But I Just Realized That...

The jaxtyping approach seems to conflict pretty hard with the @beartype approach. Both packages want to be in the driver's seat. Both packages want to apply their own AST transformations. If @beartype doesn't get to apply its AST transformations, then @beartype literally cannot support a growing laundry list of PEP standards. The @beartype.beartype decorator will then raise exceptions when presented with type hints conforming to those standards. This includes:

typeguard is even more intensely dependent on AST transformation. Far more intensely. Like... have you seen what agronholm is up to over there? He's begun refactoring typeguard hard-core into a full-featured static type-checker that runs at runtime via extremely non-trivial AST transformations. The old-school @typeguard.typechecked decorator doesn't seem to do too much anymore, not by comparison to the typeguard AST transformation, anyway. Basically, typeguard now assumes usage of its import hook to deliver typeguard-instrumented ASTs across entire packages.

@beartype is no different. We're just taking a longer and more circuitous route to get to the same place. The above list will definitely grow over time. Indeed, the mere existence of the above list suggests that nobody should call type-checking decorators in 2024. They're obsolete. They're insufficient. And they'll probably be deprecated by both @beartype and typeguard at some point in the near future. Probably not this year – but probably sooner than is comfortable for anybody.

Oh, Boy. It Comes to This.

So. I grok the @jaxtyping.jaxtyped approach. That totally makes sense for your use case of delivering human-readable type-checking violations. But you'll still need some sort of ast.NodeTransformer subclass to type-check (...waitforit) PEP 526-compliant annotated variable assignments. Whatever subclass that is, that subclass will probably also inject AST nodes decorating classes and callables with @jaxtyping.jaxtyped, right?

So. @beartype's own import hooks can still detect and apply your ast.NodeTransformer subclass after applying its own ast.NodeTransformer subclass. In this case, @beartype would probably want to stop decorating classes and callables with @beartype.beartype and let @jaxtyping.jaxtyped handle that. Of course, even that approach has problematic hot spots: namely, configuration. @beartype's AST transformer would need some way of passing on the optional keyword-only @beartype.beartype(conf: BeartypeConf) parameter through jaxtyping's AST transformer and into @beartype itself. Also, @beartype's AST transformer preferentially decorates classes rather than callables. "Ugh!", I say.

All of this would be a whole lot easier if @beartype just hurried up already and provided a public API for jaxtyping to generate human-readable type-checking violations from jaxtyping type hints. In that case, all that ~insanity~ cleverness surrounding @jaxtyping.jaxtyped would go away (with respect to @beartype, anyway).

Ultimately, @beartype, typeguard, and jaxtyping all need to "just get along" by supporting some form of ast.NodeTransformer composition. You might get me to agree to a lower-level shared third library responsible for registering type-checking import hooks, but you will never get agronholm to agree to that. Dude definitely walks his own way. Nothing wrong with that, either. I respect agronholm immensely. Stubbornness solves all problems.

There's much to ponder. Yet, the will to code big ol' plugin architectures is weak. :face_exhaling:

Crazy Idea Is Crazy

One crazy idea would be to just fold portions of jaxtyping into @beartype. I'd happily grant you push access to @beartype and free rein to inject whatever JAX-specific ~shambolic typing horrors~ wonders of joy you like into the main codebase. I trust you implicitly. Under this governance model:

Consider it! Could be fun. Or... it could be a living Hell. :smiling_imp: :open_mouth:

leycec commented 8 months ago

Wait. I'd still love to add you as a collaborator to @beartype, @patrick-kidger – but I've realized the painfully obvious while cross-country skiing the muddy and rock-strewn trails of backwoods Ontario. With great snow comes great enlightenment.

@beartype 0.17.0 will support the plugin API that jaxtyping needs, wants, and deserves already. Thankfully, this is trivial for both us and you. Whenever an object fails a type-check against a normal class, @beartype will just look for a def __instancecheck_str__(cls, obj: typing.Any) -> str: dunder method on the metaclass of that class; if found, @beartype will call that method to generate a human-readable exception message unique to that class. Usage parallels the existing def __instancecheck__(cls, obj: typing.Any) -> bool: dunder method that you're almost certainly using everywhere.
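To make the proposed protocol concrete, here is a sketch of the checker side, assuming only what is described above: isinstance() decides pass/fail, and if the hint's metaclass defines __instancecheck_str__, its return value becomes the explanation. check_instance, LengthMeta and Triple are hypothetical toy names; beartype's real message formatting differs (compare the BeartypeCallHintParamViolation output later in this thread).

```python
def check_instance(value: object, hint: type) -> None:
    """Hypothetical checker-side sketch: isinstance() decides pass/fail;
    __instancecheck_str__ (if the hint's metaclass defines it) explains why."""
    if isinstance(value, hint):
        return
    # Look the method up on the metaclass, just as for __instancecheck__.
    explain = getattr(type(hint), "__instancecheck_str__", None)
    if explain is not None:
        reason = explain(hint, value)
    else:
        reason = f"{value!r} is not an instance of {hint.__name__}"
    raise TypeError(f"{value!r} violates {hint!r}: {reason}")


class LengthMeta(type):
    """Toy hint metaclass: instances are lists whose length equals cls.length."""

    def __instancecheck__(cls, obj: object) -> bool:
        return isinstance(obj, list) and len(obj) == cls.length

    def __instancecheck_str__(cls, obj: object) -> str:
        if not isinstance(obj, list):
            return f"expected a list, got {type(obj).__name__}"
        return f"expected length {cls.length}, got length {len(obj)}"


class Triple(metaclass=LengthMeta):
    length = 3
```

Classes without the method (such as int) fall back to a generic message, so the protocol is purely opt-in for hint authors.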

When @beartype 0.17.0 does this:

In short, jaxtyping is the best.

patrick-kidger commented 8 months ago

Okay, lots of interesting ideas here! Settle in, I have a wall of text of my own.

Scope

I think we all agree that the ideal future would be for jaxtyping only to provide:

(a) the type hints; (b) the @jaxtyped decorator around a function (providing the dynamic context within which the shape-checks are performed)

and leave all type-checking and error-reporting to beartype/typeguard.

Future changes in jaxtyping

It's great that the custom error message reporting is now available in beartype! In that case, I think I see the same easy-peasy future for us as you do. I'll implement the __instancecheck_str__ method in jaxtyping, and then from that point the old-style double-decorator

@jaxtyped(typechecker=None)
@beartype
def foo(...)

approach to things should just work (hopefully).

Import hooks

As a practical matter, I expect we should be able to arrange to add both decorators using two separate import hooks.

In particular neither package is in the driver's seat. We're just both adding our own decorators via import hooks.

Actually, I realise beartype will be doing something slightly different to just adding a decorator -- you'll be adding isinstance checks elsewhere as per PEP 526, and might even do that for the arguments and return values too. jaxtyping don't care! All that matters is that all isinstance checks happen within the dynamic context of the jaxtyped decorator; that's the only thing we need to be sure to make happen.

Why I like this

(a) This means that we're not coupling beartype and jaxtyping together in any meaningful way. The only contact point is the __instancecheck_str__ method. This is good! Coupled code is a source of nightmares. (But thank you for the offer of push access to beartype!)

(b) As a practical matter, jaxtyping+v2-of-typeguard is actually a very popular combination, and this approach means that we won't be risking breaking that either. (For the sake of such approaches, I'm afraid I'll still keep the perfidious evil of jaxtyped(typechecker=...), even if it won't be the recommended path when used with beartype.)

(c) It fixes up #92, as you note!

(d) It makes it possible to use jaxtyping with PEP 526. The original purpose of this issue, lest we forget... :D

Questions for you

  1. Is naming the method __instancecheck_str__ going to be robust to future changes in Python? What we're doing here is not such an unusual thing to want, so something like Python 3.18 might add this as a method then. Could it be called __beartype_instancecheck_str__ for additional safety?

  2. I'm finding the current beartype import hooks a bit... intense. There appear to be:

    • 5 different options (beartype_this_package, beartype_package, beartype_packages, beartyping, beartype_all);
    • beartype_package and beartype_packages basically overlap;
    • is there an actual use-case for beartype_package(s) over beartype_this_package?
    • beartype_all is probably too dangerous to ever be used;
    • beartyping isn't actually documented besides the one example;
    • beartyping doesn't appear to take a package argument name. Is the idea that it grabs the __module__ from higher up the stack? (I guarantee I can break that with a delayed wrapper.)

    I'm hoping for a future in which we can recommend something consistent like:

    # foo/__init__.py
    with jaxtyping.install_import_hook("foo"), beartype.install_import_hook("foo"):
        from . import bar
        from . import baz

    or

    # foo/__init__.py
    jaxtyping.jaxtype_this_package()
    beartype.beartype_this_package()
    from . import bar
    from . import baz

    Likewise it'd be pretty neat if beartype had a pytest hook or an IPython extension.

    (You can tell I like the selection of largely-orthogonal hook options that jaxtyping provides.)

  3. Entirely unrelated, but whilst I have you here: what's the status of O(n) checking in beartype? At least for JAX this is definitely the choice we want, as all checking happens just once at trace time and never at runtime, so the overhead isn't important.

leycec commented 8 months ago

Excellence! beartype/beartype@6b3aadfff7f9e4ef1ccde is the first step on this tumultuous voyage into the unknown. Your questions are, of course, apropos and almost certainly highlight deficiencies in my worldview. To wit:

Is naming the method __instancecheck_str__() going to be robust to future changes in Python?

...heh. Let's pretend it is. Actually, the hope is that this will eventually metastasize into an actual PEP standard. In the meanwhile, I hope that somebody who is not me will market this to agronholm himself at the typeguard issue tracker. Ideally, all runtime type-checkers should transparently support __instancecheck_str__(). Right? It's just the right thing to do.

Injecting the suspiciously @beartype-specific substring "__beartype__" into this API would probably inhibit that. Likewise, @beartype would probably never support something named __typeguard_instancecheck_str__(). It's a fragile human ego thing, I think. :sweat_smile:

I'm finding the current beartype import hooks a bit... intense.

(╯°□°)╯︵ ┻━┻

5 different options

you're not wrong

beartype_package and beartype_packages basically overlap;

you're not wrong

beartype_all is probably too dangerous to ever be used;

you're not wrong

beartyping isn't actually documented besides the one example;

you're not wrong

Wait... I'm beginning to detect a deterministically repeatable pattern here. Allow me to now try but fail to explain:

Assuming beartyping actually does what it says it does, this ~presumably works~ totally doesn't:

# foo/__init__.py
with jaxtyping.install_import_hook("foo"), beartype.beartyping():
    from . import bar
    from . import baz

This definitely should work, too:

# foo/__init__.py
jaxtyping.jaxtype_this_package()  # <-- dis iz h0t
beartype.beartype_this_package()
from . import bar
from . import baz

Likewise it'd be pretty neat if beartype had a pytest hook...

Yes! So much, "Yes!" I actually implemented a pytest plugin a month ago. But then I got bored, never tested anything, and never stuffed that into an actual GitHub repository. @beartype even has an existing pytest-beartype repository – but it's still empty. Somebody else was supposed to do all that. But then the holidays and presumably video games happened.

Long story short: "Nobody did nuffin'." :face_exhaling:

...what's the status of O(n) checking in beartype?

Given that @beartype still fails to deeply type-check most standard container types like dict and set in O(1) time, let's just quietly accept that O(n) type-checking is a dream beyond the veil of sleep.
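For context on the O(1) vs O(n) distinction: beartype's documented strategy for containers is to check one randomly chosen item per call rather than every item. A minimal sketch of the two strategies for a list, using hypothetical helper names:

```python
import random


def check_list_o1(value: object, item_type: type) -> bool:
    """O(1): check the container type plus one randomly chosen item."""
    if not isinstance(value, list):
        return False
    return not value or isinstance(random.choice(value), item_type)


def check_list_on(value: object, item_type: type) -> bool:
    """O(n): check the container type plus every item."""
    return isinstance(value, list) and all(isinstance(v, item_type) for v in value)
```

A list like [1, "two", 3] always fails the O(n) check but passes the O(1) check whenever the sample lands on an int; that probabilistic coverage is the price of constant time. When checks run only once at trace time, as described for JAX further down, paying for O(n) would be cheap.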

That said, does this actually intersect with jaxtyping? From @beartype's perspective, isn't a JAX array just this monolithic object that @beartype farms off to jaxtyping?

Uhm... Err...

Oh. Wait. I never actually implemented support for jaxtyping in @beartype. Welp, that's awkward. Basically, @beartype should be doing the same thing that @beartype currently does for the third-party Pandera type hint package: namely, @beartype should silently ignore all type hints from jaxtyping. Is that right?

Previously, I'd assumed that jaxtyping was using the standard __instancecheck__() mechanism for its type hints and that @beartype didn't have to worry about anything. But that definitely doesn't seem to be the case. Is that right? Should @beartype not be type-checking objects annotated with jaxtyping type hints via the standard isinstance() builtin?

I'm not even necessarily clear what "trace time" is, frankly. My wife and I are currently sloooowly migrating our data science pipeline from Ye Ol' Mostly Single-threaded NumPy and SciPy World to Speedy Gonzales JAX World. I must confess that I am dumb, in short.

patrick-kidger commented 8 months ago
  1. On naming: hmm, I don't love it but I see the point that this may help adoption.

  2. On the number of import hooks: yeah, I'm not suggesting removing these, what with backward compatibility. I just don't love how complicated the documentation is with all these options. What I've done for things like this before is to just remove the less-favoured options from the documentation whilst keeping them around in the code. The goal being to keep a well-lit path for new users.

  3. On relative imports: I suspect relative imports will work with beartyping just fine. (Personally I use them everywhere: I much prefer the style of relative imports from within a library, and absolute imports to pull in other libraries. Has some nice benefits like not making assumptions about the PYTHONPATH etc.)

  4. On beartyping (round two): what you've got seems to break transitive imports. That is, if one of those libraries happens to import e.g. scipy, then beartype will attempt to decorate all of scipy too. (This is why jaxtyping.install_import_hook takes a package name.)

  5. On pytest + O(n) checking: fair enough! Never enough time in an open-source maintainer's life to implement everything. I'll ~bug~ ask you about these some other time.

  6. beartype already does the correct thing for jaxtyping! Just call isinstance on the type hint like any other custom class. This is what jaxtyping is designed for: just to provide hints for use as isinstance(some_array, Float[Array, "foo bar"]).

  7. To clear up what "trace time" means: this refers to actually running the Python code for a JAX program. JAX uses a JIT compiler, so as it runs through the code it records every array operation that occurs. At that point it has built up a DAG of all the array-based operations, compiles them... and (for inputs with the same shapes and dtypes) never runs the Python code again. The compiled version is what is used at runtime. This makes runtime type checking and JAX a great fit for each other, as we pay the cost of runtime type checking only once, during tracing, and not at all at "actual" runtime.
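Item 4's concern, that a hook should not follow transitive imports into e.g. scipy, comes down to filtering by module name before transforming anything. A minimal sketch using only importlib, assuming a `transform` callable that maps an ast.Module to an ast.Module; PackageHook and _TransformingLoader are hypothetical names, not the jaxtyping or beartype implementations. For simplicity it bypasses the __pycache__ bytecode cache entirely, so transformed and untransformed bytecode never get mixed; real hooks presumably handle caching more cleverly.

```python
import ast
import importlib.abc
import importlib.machinery


class _TransformingLoader(importlib.machinery.SourceFileLoader):
    """Source loader that runs an AST transform before compiling."""

    def __init__(self, fullname, path, transform):
        super().__init__(fullname, path)
        self._transform = transform

    def get_code(self, fullname):
        # Always compile from source: skip reading and writing __pycache__.
        path = self.get_filename(fullname)
        tree = self._transform(ast.parse(self.get_data(path), filename=path))
        ast.fix_missing_locations(tree)
        return compile(tree, path, "exec", dont_inherit=True)


class PackageHook(importlib.abc.MetaPathFinder):
    """Apply `transform` only to `packages` and their submodules, so that e.g.
    a transitive `import scipy` is left to the normal import machinery."""

    def __init__(self, packages, transform):
        self._packages = tuple(packages)
        self._transform = transform

    def _wanted(self, fullname):
        return any(fullname == p or fullname.startswith(p + ".") for p in self._packages)

    def find_spec(self, fullname, path, target=None):
        if not self._wanted(fullname):
            return None  # defer to the rest of sys.meta_path
        spec = importlib.machinery.PathFinder.find_spec(fullname, path)
        if spec is not None and isinstance(spec.loader, importlib.machinery.SourceFileLoader):
            spec.loader = _TransformingLoader(fullname, spec.origin, self._transform)
        return spec
```

Installing it means inserting `PackageHook(["foo"], some_transform)` at the front of sys.meta_path before `import foo`; anything foo imports from outside the "foo" namespace is untouched.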

leycec commented 8 months ago

Oh, fascinating. I just love me some JIT + DAG action. It looks like @beartype is in good paws here.

On beartyping (round two): what you've got seems to break transitive imports.

...heh. I was wondering when you'd catch that. This is actually intentional, because the current implementation of beartyping() trivially reduces to:

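from collections.abc import Iterator
from contextlib import contextmanager

from beartype import BeartypeConf
from beartype.claw import beartype_all

@contextmanager
def beartyping(conf: BeartypeConf = BeartypeConf()) -> Iterator[None]:
    try:
        beartype_all(conf=conf)
        yield
    finally:
        # Internal teardown reverting beartype_all() (not imported above).
        undo_beartype_all()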

That's... it. You are thinking:

But I Hate beartype_all().

I used to think similarly up until five minutes ago. Then I slipped off a toilet while hanging a clock shaped like a hibernating bear, banged my head on a towel rack shaped like a spawning salmon, and... I saw it. The Flux Beartyper:

# In your "{your_package}.__init__":
from beartype import BeartypeConf
from beartype.claw import beartype_this_package, beartype_all

beartype_this_package()
beartype_all(conf=BeartypeConf(
    violation_param_type=UserWarning,
    violation_return_type=UserWarning,
))

I've never documented this, but @beartype import hooks are actually backed by a pure-Python trie data structure mapping from package names to @beartype configurations. That's sorta why the whole thing took so long for me to implement. In the above example that may very well become the lead go-to example on @beartype's front page, we compel @beartype to:

The Flux Beartyper is thus the superset of mypy and pyright: it does everything those guys do (complain about everything), while also doing something those guys can never do (actually enforce something). Moreover, it selectively enforces those things on the one thing you have under your total control: your own package.
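Going only on the description above (a trie mapping package names to configurations), the lookup could work roughly like this: walk the dotted module name and keep the most specific configuration seen, with the root standing in for beartype_all()'s global default. PackageConfTrie is a hypothetical sketch; beartype's actual data structure and API may differ, and plain strings stand in for BeartypeConf objects.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class _Node:
    conf: Optional[Any] = None
    children: dict = field(default_factory=dict)


class PackageConfTrie:
    """Map dotted package names to configurations; the most specific
    registered prefix wins. The empty name "" registers the root default."""

    def __init__(self) -> None:
        self._root = _Node()

    def register(self, package: str, conf: Any) -> None:
        node = self._root
        for part in filter(None, package.split(".")):
            node = node.children.setdefault(part, _Node())
        node.conf = conf

    def lookup(self, module: str) -> Optional[Any]:
        node, best = self._root, self._root.conf
        for part in module.split("."):
            node = node.children.get(part)
            if node is None:
                break
            if node.conf is not None:
                best = node.conf
        return best
```

With a warning-only config at the root and a raising config registered for your package, "mypkg.sub" resolves to the raising config while "scipy.linalg" falls back to the warning one, which is the Flux Beartyper split. Matching on whole dotted components also means "mypkgx" is not mistaken for part of "mypkg".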

You are thinking:

"I hate the Flux Beartyper."

Tough crowd! Oh – and I've fully exercised __instancecheck_str__() with unit tests. All's well that passes. This is your final API for raising human-readable violation messages:

from beartype import beartype

class MetaclassOfMuhClass(type):
    def __instancecheck_str__(cls, obj: object) -> str:
        return (
            f'{repr(obj)} has disappointed {repr(cls)}... '
            f'for the last time.'
        )

class MuhClass(object, metaclass=MetaclassOfMuhClass):
    pass

@beartype
def muh_func(muh_obj: MuhClass) -> None:
    pass

muh_func("Some strings, you just can't reach.")

...which now raises the "human"-readable violation message:

beartype.roar.BeartypeCallHintParamViolation: Function
__main__.muh_func() parameter muh_obj="Some strings, you just can't
reach." violates type hint <class '__main__.MuhClass'>, as "Some
strings, you just can't reach." has disappointed <class
'__main__.MuhClass'>... for the last time.

Please think of a better name that Pydantic and typeguard will both accept. I'd love to satisfy everybody's wildest API passions. Is your main objection forward compatibility concerns with future Python releases? Totally valid. In theory, the "_" character between the instancecheck and str should guard a bit against that; all official dunder methods contain no internal "_" characters, so far as I know. I shrug noncommittally.

Thanks again for your painstaking interest in all this. It's a delight to dissect API design with a fellow gearhead.

@beartype 0.17.0 should drop this Friday or thereabouts. Let us light a candle in our hearts for AI QA. :candle: :robot: :candle:

avolchek commented 5 months ago

Hi! Just wanted to share some information about this issue. After __instancecheck_str__ support was added to both jaxtyping and beartype, I was able to quite easily monkey-patch beartype's import hook so that it uses jaxtyping's memo-manipulation routines. I then used beartype's hook instead of jaxtyping's, and it seems that it works correctly and checks local variable assignments like in the OP's post.

https://gist.github.com/avolchek/6ac328c435e2bd584c2722ce058de824

leycec commented 5 months ago

@avolchek: That's hilarious, horrifying, and hellish: The Three Holy H's. That's also guaranteed to break, because:

Seriously, though. That's gonna break. In fact, that probably already broke. I'm on the cusp of releasing @beartype 0.18.0. GitHub only knows where jaxtyping is going. Possibly, they don't need roads there.

I have now excoriated you, @avolchek. Now allow me to praise you! Yes, praise! Because that's actually an ingenious reverse engineering of two excruciatingly non-trivial codebases. Thanks to your solemn sacrifice and soon-to-be-broken codebase, I can happily accept that integrating @beartype + jaxtyping is actually trivial. All @beartype needs to do is:

  1. Detect jaxtyping type hints. Probably trivial. If @beartype can detect Pandera type hints (i.e., The Ultimate Pain (TUP)), @beartype can detect literally anything.
  2. On detecting a callable annotated by one or more jaxtyping type hints, dynamically inject jaxtyping setup and teardown logic into the type-checking wrapper function that @beartype generates as follows:
import jaxtyping

def __beartype_wrapper(...):
    jaxtyping._storage.push_shape_memo({})

    ...  # <-- *BEARTYPE MAGIC HAPPENS HERE*
    __beartype_pith_0 = ...  # <-- *MORE MAGICAL UNICORNS ERUPT*

    jaxtyping._storage.pop_shape_memo()
    return __beartype_pith_0

Clearly, that's trivial. Clearly, @patrick-kidger is also grinding his teeth into stubs. Don't worry! I won't do anything without your consent. For one, I'm lazy. For another, I'd have to violate privacy encapsulation. For a final one, there's no guarantee any of this dark magic will continue to behave itself in perpetuity without everyone's explicit consent and continual agreement.

How do you feel about this sort of horror, @patrick-kidger? I know you. You're like me – only stronger, fitter, and more likely to survive the collapse of Canada's rickety maple syrup market. You really want to remain "in the driver's seat." But @beartype can probably solve all your problems with just a trivial amount of integration, glue, sputum, white-knuckle grips on the keyboard, and eyes-wide-shut commits guaranteed to blow up. Should @beartype make overtures towards doing something like this or should I just back away slowly from the keyboard before anybody gets hurt and feels bad?

@patrick-kidger: Unrelatedly, Google now claims that the title for the https://github.com/patrick-kidger/jaxtyping front page is "Piotr Kaminski." When I then google the name "Piotr Kaminski," I see a concerning New York lawyer manbro who can only charitably be referred to as an "American Psycho"-era Christian Bale stunt double. I don't know what's happening. Probably, something is happening. My search for the truth bottomed out here. :laughing:

avolchek commented 5 months ago

That's also guaranteed to break

For sure :)

Detect jaxtyping type hints. Probably trivial. If @beartype can detect Pandera type hints (i.e., The Ultimate Pain (TUP)), @beartype can detect literally anything.

I'm not sure you should add the jaxtyping setup/teardown to the type-checking wrapper only when a function has jaxtyping type hints. E.g., in

def foo():
   kek: Float32[torch.Tensor, "a b c"] = ...

def bar(x: Float32[torch.Tensor, f"b w h c"]):
   ...
   foo()
   ...

the isinstance check for kek will "see" the value of c from x's shape. In contrast, the jaxtyping import hook adds {push,pop}_shape_memo to every function, regardless of whether it has jaxtyping hints or not.

patrick-kidger commented 5 months ago

Haha, @avolchek, I'm impressed!

Although, with the latest jaxtyping and beartype releases, then in theory this should "just work", just by applying both of our import hooks.

Indeed I would like to update the documentation to reflect this, however... it seems like this doesn't work!

# entry_point.py
from beartype.claw import beartype_package
from jaxtyping import install_import_hook

beartype_package("foo")
with install_import_hook("foo", typechecker=None):
    import foo

# foo.py
import jax
from jaxtyping import Array, Float

def foo():
    x: Float[Array, "size"] = jax.numpy.ones(3)
    y: Float[Array, "size"] = jax.numpy.ones(4)  # this is an error!

foo()  # but this does not raise an error!

I've tried both orders for installing the import hooks. I'm not sure what's going on here right now.

(Notably, however, something like x: Float[Array, ""] = jax.numpy.ones(3) will fail, since it's independent of the jaxtyping context.)

I think if we're to direct our efforts, it should be to be sure that this approach is usable! This would allow both packages to do what they do best, without needing to interface with each other at all.

patrick-kidger commented 5 months ago

Google now claims that the title for the https://github.com/patrick-kidger/jaxtyping front page is "Piotr Kaminski."

Hmmm, I think this is a Google bug. Searching for any of my GitHub repositories seems to grab something random off the front page of that site. I have no idea what's going on :D (Searching for e.g. github.com/google/jax seems to work though, grrr!)

avolchek commented 5 months ago

it seems like this doesn't work!

Yeah, I tried it too and it didn't work. As far as I remember, the jaxtyping hook actually doesn't do anything if the beartype hook is used at the same time. So it doesn't add {push,pop}_shape_memo, and shape "variables" will be shared across different functions in the call stack.

avolchek commented 5 months ago

# but this does not raise an error!

As far as I remember it's also because of lack of calls to memo-manipulation routines. You need to call 'push' at least once for shape variable checks to work.