Open jeezrick opened 8 months ago
So right now this will work if you use a manual isinstance check:
@jaxtyped(typechecker=beartype)
def forward(self, x: Patch_embed) -> Patch_embed:
    x: Mlp_mid = self.act(self.lin1(x))
    assert isinstance(x, Mlp_mid)
    return self.lin2(x)
but I agree that inserting this automatically would be an awesome feature to have.
I can see two possible ways we might add this:
1. Have @jaxtyped actually grab the source code and parse+rewrite the AST (or some other intermediate representation) for the function. Pro: this would add this feature everywhere that jaxtyped is used. Con: this would be exceptionally magic and might break for various interesting edge cases (e.g. generated functions, which don't actually have a source code representation).
2. Have install_import_hook also rewrite the AST for these type annotations. Pro: this is known to work as beartype takes this approach. Con: it would only work for the import hook, and not for manual jaxtyped decorators.

Tagging @leycec -- I know you've tackled this problem before; do you have any thoughts on the matter?
BTW, I mainly use it when I am trying to understand others code. But can I include it in production? How much does it slow down the training and inference?
If you're using JAX, then it shouldn't slow runtime at all. There will only be a negligible amount of work at compile time, just to actually check that the shapes and dtypes of the arrays are what they say they are.
If you're using PyTorch/etc. then the amount of overhead hasn't really been benchmarked, so I'm less sure :)
...heh. The age-old PEP 526-compliant Annotated Variable Assignment Runtime Type-checking Problem, huh? That one just never gets old. So, thanks for pinging me on! I love this sort of hacky Python wrangling.
Sadly, you already know everything. This is like when Luke in Empire Strikes Back finally realizes that the little wizened green lizard blob thing is actually the most powerful surviving Jedi in the entire universe. You are that blob thing. Wait. That... no longer sounds complimentary. Let's awkwardly start over. </ahem>
As you astutely surmise, two sucky roads lie before you:
1. Do what typeguard does. That is, supercharge @jaxtyped to instrument ASTs. typeguard code can be a little arduous to reverse-engineer, owing to a general lack of internal commentary in that codebase. From the little to nothing that I've gleaned, the @typeguard.typechecked decorator attempts to do this. I italicize attempts, because I remain unconvinced that that genuinely works in the general case. In-memory classes and callables will be the test failures of us all. @beartype briefly considered doing this and then quickly thought better of it. A bridge too far is a bridge that will explode, plummeting all of us to our shrieking dooms below. (also, i have no idea how to do this)
2. Supercharge install_import_hook with a new visit_AnnAssign() method. This is the least magical and thus best possible approach.

Relatedly, you almost certainly want to ~steal everything not nailed down~ be inspired by @beartype's own visit_AnnAssign() implementation. The proliferation of edge cases makes the implementation non-trivial. Because @beartype and the Power of Asperger's, the code is outrageously commented. You can probably streamline that by... a lot.
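For readers who have not met visit_AnnAssign() before, here is a minimal sketch of the idea using only the standard library. It is not the beartype or jaxtyping implementation: the class name AnnAssignChecker and the helper instrument() are hypothetical, and the sketch ignores nearly all of the edge cases mentioned above (it only handles simple `name: hint = value` statements and relies on a plain isinstance() check).

```python
import ast
import copy


class AnnAssignChecker(ast.NodeTransformer):
    """Hypothetical transformer: follow each `name: hint = value` statement
    with `assert isinstance(name, hint)`."""

    def visit_AnnAssign(self, node: ast.AnnAssign):
        self.generic_visit(node)
        # Skip bare declarations (`x: int`) and attribute/subscript targets.
        if node.value is None or not isinstance(node.target, ast.Name):
            return node
        check = ast.Assert(
            test=ast.Call(
                func=ast.Name(id="isinstance", ctx=ast.Load()),
                args=[
                    ast.Name(id=node.target.id, ctx=ast.Load()),
                    copy.deepcopy(node.annotation),
                ],
                keywords=[],
            ),
            msg=ast.Constant(value=f"{node.target.id} violates its annotation"),
        )
        # Returning a list replaces the one statement with two statements.
        return [node, ast.copy_location(check, node)]


def instrument(source: str) -> dict:
    """Parse, rewrite, compile and execute `source`; return its namespace."""
    tree = AnnAssignChecker().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    namespace: dict = {}
    exec(compile(tree, "<instrumented>", "exec"), namespace)
    return namespace


ns = instrument("def f(v):\n    x: int = v\n    return x\n")
ns["f"](3)        # passes the injected check
# ns["f"]("a")    # would raise AssertionError from the injected check
```

Because the injected check is a plain assert, it disappears under `python -O`; a real implementation would presumably raise a dedicated exception instead.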
Tangentially, would you like @beartype to automate your import hooks for you?
@beartype would be delighted to silently piggyback jaxtyping onto its own hunched back. @beartype already provides a rich constellation of import hook primitives (e.g., beartype_this_package(), beartype_packages(), beartyping()). Moreover, anyone who wants @beartype probably also wants jaxtyping. Moreover, it's optimally efficient for @beartype to circumvent all of the file I/O and CPU churn associated with the jaxtyping.install_import_hook() import hook by instead just directly applying your currently private (and thus horribly risky) jaxtyping._import_hook.JaxtypingTransformer subclass where @beartype applies its own beartype.claw._ast.clawastmain.BeartypeNodeTransformer subclass.
Specifically, @beartype would automate away everything by:
1. Detecting whether jaxtyping is importable. Trivial.
2. If jaxtyping is importable:
   - Applying your jaxtyping._import_hook.JaxtypingTransformer subclass immediately after or before (...not sure which) applying its own beartype.claw._ast.clawastmain.BeartypeNodeTransformer subclass. Trivial. If I'm reading your code correctly, @beartype can just do something like:

# In our private "beartype.claw._importlib._clawimpload" submodule:
...
# Attempt to...
try:
    # Defer optional dependency imports.
    from jaxtyping._import_hook import JaxtypingTransformer

    # AST transformer decorating typed callables and classes by "jaxtyping".
    #
    # Note that we intentionally pass *NO* third-party "typechecker" to avoid
    # redundantly applying the @beartype.beartype decorator to the same callables
    # and classes twice.
    ast_jaxtyper = JaxtypingTransformer(typechecker=None)

    # Abstract syntax tree (AST) modified by this transformer.
    module_ast = ast_jaxtyper.visit(module_ast)
# If "jaxtyping" is currently unimportable, silently pretend everything is well.
except ImportError:
    pass

# AST transformer decorating typed callables and classes by @beartype.
ast_beartyper = BeartypeNodeTransformer(
    conf_beartype=self._module_conf_beartype)

# Abstract syntax tree (AST) modified by this transformer.
module_ast_beartyped = ast_beartyper.visit(module_ast)
Of course, you don't need to actually depend upon or require @beartype in any way. No changes are needed on your end. Actually, one little change would improve sanity: if you wouldn't mind publicizing JaxtypingTransformer somewhere, @beartype can then just publicly import that and no longer worry about destroying itself.
That's it. Super-easy, honestly. Everyone currently getting @beartype would then get jaxtyping
as well for free. Free is good. Let us give the good people what they want, @patrick-kidger. Smile, benevolent cat! Smile! :smile_cat:
Oh – and @beartype now officially supports Equinox. 2024: dis goin' be gud.
Thank you @leycec! Okay, looks like we're stuck with these two options. I'll have a think about what implementing these looks like. I'm not wild about either approach, they're both pretty magic... :D The beartype implementation in particular is a very useful reference.
As for having the AST transformers / import hooks work together... hmm, I think I'd need to understand this better still. I think right now what we've got aren't really directly compatible with each other. We might need some more changes elsewhere before this is doable.
To explain: jaxtyping doesn't really want you to use JaxtypingTransformer(typechecker=None). This is the old way of doing things.
As of fairly recently, we now do some fairly evil things under the hood, corresponding to passing typechecker=beartype.beartype. This terrifying block of code will transform
@jaxtyped(typechecker=beartype)
def foo(x: Ann) -> Ret:
    ... # do stuff

out = foo(bar)
into
@beartype
def foo_args(x: Ann):
    pass

try:
    foo_args(bar)
except Exception as e:
    raise ValueError("A helpful error message for the arguments") from e

out = foo(bar)

@beartype
def foo_ret(x: Ann) -> Ret:
    return out

try:
    foo_ret(bar)
except Exception as e:
    raise ValueError("A helpful error message for the return value") from e
where foo_args
and foo_ret
are actually functions dynamically generated at decoration time. Fun.
The reason for this is so that we can finally give those nice error messages about shapes and dtypes and what-not when something goes wrong. What this means is that if jaxtyping is being used, it's basically assuming that it's "in charge": it gets to decide how to report error messages, not beartype! (Although you can scroll up a bit in the traceback to see the underlying error message that was caught and attached as __cause__
.)
This is actually pretty nice from the jaxtyping point-of-view. We don't have to worry about whether you're using beartype or typeguard or anything else: if the decorator raises an exception, we can use it.
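To make the "if the decorator raises an exception, we can use it" point concrete, here is a small stdlib-only sketch of the pattern. Everything in it is hypothetical: toy_typechecker stands in for beartype or typeguard, and checked() is a simplified illustration of wrapping an arbitrary type-checking decorator, not jaxtyping's actual code. It only covers the argument check, not the return value.

```python
import functools
import inspect


def toy_typechecker(fn):
    """Stand-in for beartype/typeguard: isinstance-checks annotated arguments."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = sig.parameters[name].annotation
            if hint is not inspect.Parameter.empty and not isinstance(value, hint):
                raise TypeError(f"{name}={value!r} is not {hint.__name__}")
        return fn(*args, **kwargs)

    return wrapper


def checked(typechecker):
    """Wrap any type-checking decorator and re-raise its errors with context."""

    def decorator(fn):
        # A do-nothing function carrying fn's signature, generated at
        # decoration time; the type checker only ever sees this one.
        @functools.wraps(fn)
        def signature_only(*args, **kwargs):
            return None

        check_args = typechecker(signature_only)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                check_args(*args, **kwargs)
            except Exception as e:
                # The original error stays reachable as __cause__.
                raise ValueError(
                    f"Type-check error in the arguments of {fn.__name__}"
                ) from e
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@checked(toy_typechecker)
def double(x: int):
    return 2 * x
```

The outer wrapper never needs to know which checker it was given: any exception type is caught and chained, which is the property described above.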
So, what does this mean at this point? Frankly, I'm not 100% sure. :D It's definitely not optimal for jaxtyping to have to reimplement the magic required for checking statement annotations. Perhaps we could factor out all the import hook business into a shared third library? Possibly that leads to fire and explosions.
Equinox: hurrah! That's awesome. What an excellent Christmas gift.
Woooooooah. Indeed, I see terrifying – yet ultimately justifiable – shenanigans that sadden me. In the absence of a standardized plugin system for runtime-static type-checkers generically supported by both @beartype and typeguard
, though, what else you gonna do? Right? So you force the issue by just doing everything yourself. Still, I feel a gnawing fear. That's some pretty gnarly AST munging just to circumvent the lack of a sane plugin API in other people's code.
@beartype and typeguard
alike are clearly failing jaxtyping
here. But... it's not just jaxtyping
we're failing. It's any third-party package publishing custom type hint factories like jaxtyping.Float[...]
, really.
My only wish for 2024 is to stop failing jaxtyping
. That, and for our Maine Coons to finally cough up their pernicious hairballs. Just give it up already, cats! :cat2:
The jaxtyping approach seems to conflict pretty hard with the @beartype approach. Both packages want to be in the driver's seat. Both packages want to apply their own AST transformations. If @beartype doesn't get to apply its AST transformations, then @beartype literally cannot support a growing laundry list of PEP standards. The @beartype.beartype decorator will then raise exceptions when presented with type hints conforming to those standards. This includes:
- type alias statements. Sadly, type aliases can't be properly type-checked at runtime without transforming the AST containing those statements. It is sucky.

typeguard is even more intensely dependent on AST transformation. Far more intensely. Like... have you seen what agronholm is up to over there? He's begun refactoring typeguard hard-core into a full-featured static type-checker that runs at runtime via extremely non-trivial AST transformations. The old-school @typeguard.typechecked decorator doesn't seem to do too much anymore – not by comparison to the typeguard AST transformation, anyway. Basically, typeguard now assumes usage of its import hook to deliver typeguard-instrumented ASTs across entire packages.
@beartype is no different. We're just taking a longer and more circuitous route to get to the same place. The above list will definitely grow over time. Indeed, the mere existence of the above list suggests that nobody should call type-checking decorators in 2024. They're obsolete. They're insufficient. And they'll probably be deprecated by both @beartype and typeguard
at some point in the near future. Probably not this year – but probably sooner than is comfortable for anybody.
So. I grok the @jaxtyping.jaxtyped
approach. That totally makes sense for your use case of delivering human-readable type-checking violations. But you'll still need some sort of ast.NodeTransformer
subclass to type-check (...waitforit) PEP 526-compliant annotated variable assignments. Whatever subclass that is, that subclass will probably also inject AST nodes decorating classes and callables with @jaxtyping.jaxtyped
, right?
So. @beartype's own import hooks can still detect and apply your ast.NodeTransformer
subclass after applying its own ast.NodeTransformer
subclass. In this case, @beartype would probably want to stop decorating classes and callables with @beartype.beartype
and let @jaxtyping.jaxtyped
handle that. Of course, even that approach has problematic hot spots: namely, configuration. @beartype's AST transformer would need some way of passing on the optional keyword-only @beartype.beartype(conf: BeartypeConf)
parameter through jaxtyping
's AST transformer and into @beartype itself. Also, @beartype's AST transformer preferentially decorates classes rather than callables. "Ugh!", I say.
All of this would be a whole lot easier if @beartype just hurried up already and provided a public API for jaxtyping to generate human-readable type-checking violations from jaxtyping type hints. In that case, all that ~insanity~ cleverness surrounding @jaxtyping.jaxtyped would go away (with respect to @beartype, anyway).
Ultimately, @beartype, typeguard
, and jaxtyping
all need to "just get along" by supporting some form of ast.NodeTransformer
composition. You might get me to agree to a lower-level shared third library responsible for registering type-checking import hooks, but you will never get agronholm to agree to that. Dude definitely walks his own way. Nothing wrong with that, either. I respect agronholm immensely. Stubbornness solves all problems.
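As a rough illustration of what "ast.NodeTransformer composition" could mean in practice, the sketch below applies several transformers to one tree in a fixed order. The transformer AddDecorator, the helper compose(), and the decorator names inner_check and outer_context are all hypothetical placeholders; none of them are part of beartype, typeguard, or jaxtyping.

```python
import ast


class AddDecorator(ast.NodeTransformer):
    """Hypothetical transformer: prepend `@<name>` to every function."""

    def __init__(self, name: str) -> None:
        self.name = name

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)
        node.decorator_list.insert(0, ast.Name(id=self.name, ctx=ast.Load()))
        return node


def compose(source: str, transformers) -> ast.Module:
    """Apply each transformer in order to the same module AST."""
    tree = ast.parse(source)
    for transformer in transformers:
        tree = transformer.visit(tree)
    ast.fix_missing_locations(tree)
    return tree


tree = compose(
    "def f():\n    pass\n",
    [AddDecorator("inner_check"), AddDecorator("outer_context")],
)
decorators = [d.id for d in tree.body[0].decorator_list]
```

Here the transformer applied last ends up as the outermost decorator, so the order in which hooks run decides which package wraps which. That ordering question is exactly the "driver's seat" problem discussed above.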
There's much to ponder. Yet, the will to code big ol' plugin architectures is weak. :face_exhaling:
One crazy idea would be to just fold portions of jaxtyping into @beartype. I'd happily grant you push access to @beartype and free rein to inject whatever JAX-specific shambolic typing ~horrors~ wonders of joy you like into the main codebase. I trust you implicitly. Under this governance model:
- jaxtyping itself would still be responsible for publishing and maintaining JAX type hints. However...

Consider it! Could be fun. Or... it could be a living Hell. :smiling_imp: :open_mouth:
Wait. I'd still love to add you as a collaborator to @beartype, @patrick-kidger – but I've realized the painfully obvious while cross-country skiing the muddy and rock-strewn trails of backwoods Ontario. With great snow comes great enlightenment.
@beartype 0.17.0 will support the plugin API that jaxtyping
needs, wants, and deserves already. Thankfully, this is trivial for both us and you. Whenever an object fails a type-check against a normal class, @beartype will just look for a def __instancecheck_str__(cls, obj: typing.Any) -> str:
dunder method on the metaclass of that class; if found, @beartype will call that method to generate a human-readable exception message unique to that class. Usage parallels the existing def __instancecheck__(cls, obj: typing.Any) -> bool:
dunder method that you're almost certainly using everywhere.
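The lookup described here can be sketched from the type checker's side in a few lines. This is only an illustration of the protocol as described in this comment, not beartype's implementation; PositiveIntMeta, PositiveInt and explain_failure() are hypothetical names.

```python
class PositiveIntMeta(type):
    """Hypothetical metaclass defining both dunder methods."""

    def __instancecheck__(cls, obj: object) -> bool:
        return isinstance(obj, int) and obj > 0

    def __instancecheck_str__(cls, obj: object) -> str:
        return f"{obj!r} is not a positive int"


class PositiveInt(metaclass=PositiveIntMeta):
    pass


def explain_failure(obj: object, hint: type):
    """Return None if `obj` passes, else a human-readable reason."""
    if isinstance(obj, hint):
        return None
    # Look the hook up on the metaclass, mirroring how __instancecheck__
    # itself is resolved.
    hook = getattr(type(hint), "__instancecheck_str__", None)
    if hook is not None:
        return hook(hint, obj)
    # Fallback for ordinary classes that do not define the hook.
    return f"{obj!r} is not an instance of {hint!r}"
```

A class that does not opt in (for example int) simply falls through to the generic message, so supporting the hook costs the checker one attribute lookup on the failure path.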
When @beartype 0.17.0 does this:
- jaxtyping is the best.
- ...@jaxtyped decorator is currently doing to the @beartype decorator. I'm all for perfidious evil. Don't get me wrong. As a recuperating metalhead, Ley "Perfidious Evil" Cec is surely my true name. That said, this particular perfidious evil should no longer be needed. Which is good, as not doing that will also trivially resolve #92... and heaps of other outstanding issues that have probably gone unreported, honestly.

In short, jaxtyping is the best.
Okay, lots of interesting ideas here! Settle in, I have a wall of text of my own.
I think we all agree that the ideal future would be for jaxtyping only to provide:
(a) the type hints;
(b) the @jaxtyped
decorator around a function (providing the dynamic context within which the shape-checks are performed)
and leave all type-checking and error-reporting to beartype/typeguard.
It's great that the custom error message reporting is now available in beartype! In that case, I think I see the same easy-peasy future for us as you do. I'll implement the __instancecheck_str__
method in jaxtyping, and then from that point the old-style double-decorator
@jaxtyped(typechecker=None)
@beartype
def foo(...):
    ...
approach to things should just work.
hopefully
As a practical matter, I expect we should be able to arrange to add both decorators using two separate import hooks.
In particular neither package is in the driver's seat. We're just both adding our own decorators via import hooks.
Actually, I realise beartype will be doing something slightly different to just adding a decorator -- you'll be adding isinstance
checks elsewhere as per PEP 526, and might even do that for the arguments and return values too. jaxtyping don't care! All that matters is that all isinstance
checks happen within the dynamic context of the jaxtyped
decorator; that's the only thing we need to be sure to make happen.
(a) This means that we're not coupling beartype and jaxtyping together in any meaningful way. The only contact point is the __instancecheck_str__
method. This is good! Coupled code is a source of nightmares. (But thank you for the offer of push access to beartype!)
(b) As a practical matter, jaxtyping+v2-of-typeguard is actually a very popular combination, and this approach means that we won't be risking breaking that either. (For the sake of such approaches, I'm afraid I'll still keep the perfidious evil of jaxtyped(typechecker=...), even if it won't be the recommended path when used with beartype.)
(c) It fixes up #92, as you note!
(d) It makes it possible to use jaxtyping with PEP 526. The original purpose of this issue, lest we forget... :D
Is naming the method __instancecheck_str__
going to be robust to future changes in Python? What we're doing here is not such an unusual thing to want, so something like Python 3.18 might add this as a method then. Could it be called __beartype_instancecheck_str__
for additional safety?
I'm finding the current beartype import hooks a bit... intense. There appear to be:
- 5 different options (beartype_this_package, beartype_package, beartype_packages, beartyping, beartype_all);
- beartype_package and beartype_packages basically overlap;
- ...beartype_package(s) over beartype_this_package?
- beartype_all is probably too dangerous to ever be used;
- beartyping isn't actually documented besides the one example;
- beartyping doesn't appear to take a package argument name. Is the idea that it grabs the __module__ from higher up the stack? (I guarantee I can break that with a delayed wrapper.)

I'm hoping for a future in which we can recommend something consistent like:
# foo/__init__.py
with jaxtyping.install_import_hook("foo"), beartype.install_import_hook("foo"):
    from . import bar
    from . import baz
or
# foo/__init__.py
jaxtyping.jaxtype_this_package()
beartype.beartype_this_package()
from . import bar
from . import baz
Likewise it'd be pretty neat if beartype had a pytest hook or an IPython extension.
(You can tell I like the selection of largely-orthogonal hook options that jaxtyping provides.)
Entirely unrelated, but whilst I have you here: what's the status of O(n)
checking in beartype? At least for JAX this is definitely the choice we want, as all checking happens just once at trace time and never at runtime, so the overhead isn't important.
Excellence! beartype/beartype@6b3aadfff7f9e4ef1ccde is the first step on this tumultuous voyage into the unknown. Your questions are, of course, apropos and almost certainly highlight deficiencies in my worldview. To wit:
Is naming the method
__instancecheck_str__()
going to be robust to future changes in Python?
...heh. Let's pretend it is. Actually, the hope is that this will eventually metastasize into an actual PEP standard. In the meanwhile, I hope that somebody who is not me will market this to agronholm himself at the typeguard
issue tracker. Ideally, all runtime type-checkers should transparently support __instancecheck_str__()
. Right? It's just the right thing to do.
Injecting the suspiciously @beartype-specific substring "__beartype__"
into this API would probably inhibit that. Likewise, @beartype would probably never support something named __typeguard_instancecheck_str__()
. It's a fragile human ego thing, I think. :sweat_smile:
I'm finding the current beartype import hooks a bit... intense.
(╯°□°)╯︵ ┻━┻
5 different options
you're not wrong
beartype_package
andbeartype_packages
basically overlap;
you're not wrong
beartype_all
is probably too dangerous to ever be used;
you're not wrong
beartyping
isn't actually documented besides the one example;
you're not wrong
Wait... I'm beginning to detect a deterministically repeatable pattern here. Allow me to now try but fail to explain:
- ...beartype_this_package() and beartype_packages(). But everybody else who piled onto the pre-release beta test of the beartype.claw subpackage wanted all of those other things. Things got out of hand quickly. This is what happens when you do what other people want. Now, @beartype can't back out of any of those decisions without breaking backward compatibility across the Python ecosystem. Thankfully, it's mostly fine. These functions have yet to actually break or do anything bad. So, there's no maintenance burden or technical debt here. They just kinda hang out, doing their own thing. I salute these functions I do not need and never really wanted.
- ...beartyping. But nobody complained, I got tired, and then video games happened.
- beartyping just contextually applies to everything in the body of its with beartyping(...): block. I didn't realize that relative imports might break that. I never even considered relative imports anywhere. I always do absolute imports everywhere, because I am obsessive-compulsive and have trust issues. Do relative imports break beartyping? I... don't know, actually. I should probably go and test that. But... I probably won't until somebody complains. I'm still tired and video games are still happening.

Assuming beartyping actually does what it says it does, this presumably works: totally doesn't
# foo/__init__.py
with jaxtyping.install_import_hook("foo"), beartype.beartyping():
    from . import bar
    from . import baz
This definitely should work, too:
# foo/__init__.py
jaxtyping.jaxtype_this_package() # <-- dis iz h0t
beartype.beartype_this_package()
from . import bar
from . import baz
Likewise it'd be pretty neat if beartype had a pytest hook...
Yes! So much, "Yes!" I actually implemented a pytest
plugin a month ago. But then I got bored, never tested anything, and never stuffed that into an actual GitHub repository. @beartype even has an existing pytest-beartype
repository – but it's still empty. Somebody else was supposed to do all that. But then the holidays and presumably video games happened.
Long story short: "Nobody did nuffin'." :face_exhaling:
...what's the status of
O(n)
checking in beartype?
Given that @beartype still fails to deeply type-check most standard container types like dict
and set
in O(1)
time, let's just quietly accept that O(n)
type-checking is a dream beyond the veil of sleep.
That said, does this actually intersect with jaxtyping
? From @beartype's perspective, isn't a JAX array just this monolithic object that @beartype farms off to jaxtyping
?
Oh. Wait. I never actually implemented support for jaxtyping
in @beartype. Welp, that's awkward. Basically, @beartype should be doing the same thing that @beartype currently does for the third-party Pandera type hint package: namely, @beartype should silently ignore all type hints from jaxtyping
. Is that right?
Previously, I'd assumed that jaxtyping
was using the standard __instancecheck__()
mechanism for its type hints and that @beartype didn't have to worry about anything. But that definitely doesn't seem to be the case. Is that right? Should @beartype not be type-checking objects annotated with jaxtyping
type hints via the standard isinstance()
builtin?
I'm not even necessarily clear what "trace time" is, frankly. My wife and I are currently sloooowly migrating our data science pipeline from Ye Ol' Mostly Single-threaded NumPy and SciPy World to Speedy Gonzalez JAX World. I must confess that I am dumb, in short.
On naming: hmm, I don't love it but I see the point that this may help adoption.
On the number of import hooks: yeah, I'm not suggesting removing these, what with backward compatibility. I just don't love how complicated the documentation is with all these options. What I've done for things like this before is to just remove the less-favoured options from the documentation whilst keeping them around in the code. The goal being to keep a well-lit path for new users.
On relative imports: I suspect relative imports will work with beartyping
just fine. (Personally I use them everywhere: I much prefer the style of relative imports from within a library, and absolute imports to pull in other libraries. Has some nice benefits like not making assumptions about the PYTHONPATH
etc.)
On beartyping
(round two): what you've got seems to break transitive imports. That is, if one of those libraries happens to import e.g. scipy, then beartype will attempt to decorate all of scipy too. (This is why jaxtyping.install_import_hook
takes a package name.)
On pytest + O(n)
checking: fair enough! Never enough time in an open-source maintainer's life to implement everything. I'll ~bug~ ask you about these some other time.
beartype already does the correct thing for jaxtyping! Just call isinstance
on the type hint like any other custom class. This is what jaxtyping is designed for: just to provide hints for use as isinstance(some_array, Float[Array, "foo bar"])
.
To clear up what "trace time" means -- this refers to actually running the Python code for a JAX program. JAX uses a JIT compiler, so as it runs through the code it records every array operation that occurs. At that point it's built up a DAG of all the array-based operations, compiles them... and never runs the Python code again. The compiled version is what is used at runtime. This makes runtime type checking and JAX a great fit for each other, as we pay the cost of runtime type checking just once during this tracing, and not at all at "actual" runtime.
Oh, fascinating. I just love me some JIT + DAG action. It looks like @beartype is in good paws here.
On
beartyping
(round two): what you've got seems to break transitive imports.
...heh. I was wondering when you'd catch that. This is actually intentional, because the current implementation of beartyping()
trivially reduces to:
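from collections.abc import Iterator
from contextlib import contextmanager

from beartype import BeartypeConf
from beartype.claw import beartype_all

@contextmanager
def beartyping(conf: BeartypeConf = BeartypeConf()) -> Iterator[None]:
    try:
        beartype_all(conf=conf)
        yield
    finally:
        # Named as in the original post; not imported here.
        undo_beartype_all()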
That's... it. You are thinking:
- ...beartype_all().

I used to think similarly up until five minutes ago. Then I slipped off a toilet while hanging a clock shaped like a hibernating bear, banged my head on a towel rack shaped like a spawning salmon, and... I saw it. The Flux Beartyper:
# In your "{your_package}.__init__":
from beartype import BeartypeConf
from beartype.claw import beartype_this_package, beartype_all

beartype_this_package()
beartype_all(conf=BeartypeConf(
    violation_param_type=UserWarning,
    violation_return_type=UserWarning,
))
I've never documented this, but @beartype import hooks are actually backed by a pure-Python trie data structure mapping from package names to @beartype configurations. That's sorta why the whole thing took so long for me to implement. In the above example that may very well become the lead goto example on @beartype's front page, we compel @beartype to:
The Flux Beartyper is thus the superset of mypy and pyright
: it does everything those guys do (complain about everything), while also doing something those guys can never do (actually enforce something). Moreover, it selectively enforces those things on the one thing you have under your total control: your own package.
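The trie mentioned above is not documented in this thread, so the following is only a guess at the general shape of such a structure: a tree keyed on dotted package-name components, where the most specific registered configuration wins. PackageTrie and its methods are hypothetical names, and plain strings stand in for BeartypeConf objects.

```python
from typing import Any, Optional


class PackageTrie:
    """Hypothetical trie mapping dotted package names to configurations."""

    def __init__(self) -> None:
        self.conf: Optional[Any] = None
        self.children: dict = {}

    def register(self, package: Optional[str], conf: Any) -> None:
        """Register `conf` for `package`; None means 'all packages'."""
        node = self
        for part in (package.split(".") if package else []):
            node = node.children.setdefault(part, PackageTrie())
        node.conf = conf

    def lookup(self, module: str) -> Optional[Any]:
        """Return the configuration of the longest registered prefix."""
        node, best = self, self.conf
        for part in module.split("."):
            node = node.children.get(part)
            if node is None:
                break
            if node.conf is not None:
                best = node.conf
        return best


trie = PackageTrie()
trie.register(None, "warn")     # analogous to beartype_all(conf=...)
trie.register("foo", "strict")  # analogous to beartype_this_package() in foo
```

With these two registrations, modules under foo resolve to the strict configuration while everything else falls back to the warn-only one, which is the behaviour the example above relies on.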
You are thinking:
Tough crowd! Oh – and I've fully exercised __instancecheck_str__()
with unit tests. All's well that passes. This is your final API for raising human-readable violation messages:
from beartype import beartype

class MetaclassOfMuhClass(type):
    def __instancecheck_str__(cls, obj: object) -> str:
        return (
            f'{repr(obj)} has disappointed {repr(cls)}... '
            f'for the last time.'
        )

class MuhClass(object, metaclass=MetaclassOfMuhClass):
    pass

@beartype
def muh_func(muh_obj: MuhClass) -> None:
    pass

muh_func("Some strings, you just can't reach.")
...which now raises the "human"-readable violation message:
beartype.roar.BeartypeCallHintParamViolation: Function
__main__.muh_func() parameter muh_obj="Some strings, you just can't
reach." violates type hint <class '__main__.MuhClass'>, as "Some
strings, you just can't reach." has disappointed <class
'__main__.MuhClass'>... for the last time.
Please think of a better name that Pydantic and typeguard
will both accept. I'd love to satisfy everybody's wildest API passions. Is your main objection forward compatibility concerns with future Python releases? Totally valid. In theory, the "_"
character between the instancecheck
and str
should guard a bit against that; all official dunder methods contain no internal "_"
characters, so far as I know. I shrug noncommittally.
Thanks again for your painstaking interest in all this. It's a delight to dissect API design with a fellow gearhead.
@beartype 0.17.0 should drop this Friday or thereabouts. Let us light a candle in our hearts for AI QA. :candle: :robot: :candle:
Hi!
Just wanted to share some information about this issue. After __instancecheck_str__
support was added to both jaxtyping
and beartype
, I was able to quite easily monkey-patch beartype
's import hook so that it uses jaxtyping
's memo-manipulation routines. I then used beartype
's hook instead of jaxtyping
's, and it seems that it works correctly and checks local variable assignments like in the OP's post.
https://gist.github.com/avolchek/6ac328c435e2bd584c2722ce058de824
@avolchek: That's hilarious, horrifying, and hellish: The Three Holy H's. That's also guaranteed to break, because:
- ...beartype._decor._decornontype.generate_code, which isn't even the actual submodule defining that function; you just happened to get accidentally lucky with a fragile import, which is kinda funny but also deeply concerning). I often refactor the entire @beartype codebase, because I am a pedantic code masochist and enjoy public displays of self-flagellation.
- ...jaxtyping privacy encapsulation in dangerous ways that are exploding as we speak (e.g., jaxtyping._storage.push_shape_memo({})). My Gods! How much privacy can one monkey-patch violate!? The answer may shock you, because I'm looking at it right here. Dynamite is the gist of this gist.

Seriously, though. That's gonna break. In fact, that probably already broke. I'm on the cusp of releasing @beartype 0.18.0. GitHub only knows where jaxtyping is going. Possibly, they don't need roads there.
I have now excoriated you, @avolchek. Now allow me to praise you! Yes, praise! Because that's actually an ingenious reverse engineering of two excruciatingly non-trivial codebases. Thanks to your solemn sacrifice and soon-to-be-broken codebase, I can happily accept that integrating @beartype + jaxtyping is actually trivial. All @beartype needs to do is:
1. Detect jaxtyping type hints. Probably trivial. If @beartype can detect Pandera type hints (i.e., The Ultimate Pain (TUP)), @beartype can detect literally anything.
2. On detecting jaxtyping type hints, dynamically inject jaxtyping setup and teardown logic into the type-checking wrapper function that @beartype generates as follows:

import jaxtyping

def __beartype_wrapper(...):
    jaxtyping._storage.push_shape_memo({})
    ... # <-- *BEARTYPE MAGIC HAPPENS HERE*
    __beartype_pith_0 = ... # <-- *MORE MAGICAL UNICORNS ERUPT*
    jaxtyping._storage.pop_shape_memo()
    return __beartype_pith_0
Clearly, that's trivial. Clearly, @patrick-kidger is also grinding his teeth into stubs. Don't worry! I won't do anything without your consent. For one, I'm lazy. For another, I'd have to violate privacy encapsulation. For a final one, there's no guarantee any of this dark magic will continue to behave itself in perpetuity without everyone's explicit consent and continual agreement.
How do you feel about this sort of horror, @patrick-kidger? I know you. You're like me – only stronger, fitter, and more likely to survive the collapse of Canada's rickety maple syrup market. You really want to remain "in the driver's seat." But @beartype can probably solve all your problems with just a trivial amount of integration, glue, sputum, white-knuckle grips on the keyboard, and eyes-wide-shut commits guaranteed to blow up. Should @beartype make overtures towards doing something like this or should I just back away slowly from the keyboard before anybody gets hurt and feels bad?
@patrick-kidger: Unrelatedly, Google now claims that the title for the https://github.com/patrick-kidger/jaxtyping
front page is "Piotr Kaminski." When I then google the name "Piotr Kaminski," I see a concerning New York lawyer manbro who can only charitably be referred to as an "American Psycho"-era Christian Bale stunt double. I don't know what's happening. Probably, something is happening. My search for the truth bottomed out here. :laughing:
That's also guaranteed to break
For sure :)
Detect jaxtyping type hints. Probably trivial. If @beartype can detect Pandera type hints (i.e., The Ultimate Pain (TUP)), @beartype can detect literally anything.
I'm not sure that you need to add jaxtyping stuff into the type-checking wrapper only if you have jaxtyping type hints. E.g. in

def foo():
    kek: Float32[torch.Tensor, "a b c"] = ...

def bar(x: Float32[torch.Tensor, f"b w h c"]):
    ...
    foo()
    ...

the isinstance check for kek will "see" the value of c from the shape of x. In contrast, the jaxtyping import hook adds {push,pop}_shape_memo to every function, regardless of whether it has jaxtyping hints or not.
Haha, @avolchek, I'm impressed!
Although, with the latest jaxtyping and beartype releases, then in theory this should "just work", just by applying both of our import hooks.
Indeed I would like to update the documentation to reflect this, however... it seems like this doesn't work!
# entry_point.py
from beartype.claw import beartype_package
from jaxtyping import install_import_hook

beartype_package("foo")
with install_import_hook("foo", typechecker=None):
    import foo

# foo.py
import jax
from jaxtyping import Array, Float

def foo():
    x: Float[Array, "size"] = jax.numpy.ones(3)
    y: Float[Array, "size"] = jax.numpy.ones(4)  # this is an error!

foo()  # but this does not raise an error!
I've tried both orders for installing the import hooks. I'm not sure what's going on here right now.
(Notably, however, something like x: Float[Array, ""] = jax.numpy.ones(3)
will fail, since it's independent of the jaxtyping context.)
I think if we're to direct our efforts, it should be to be sure that this approach is usable! This would allow both packages to do what they do best, without needing to interface with each other at all.
Google now claims that the title for the https://github.com/patrick-kidger/jaxtyping front page is "Piotr Kaminski."
Hmmm, I think this is a Google bug. Searching for any of my GitHub repositories seems to grab something random off the front page of that site. I have no idea what's going on :D (Searching for e.g. github.com/google/jax seems to work though, grrr!)
it seems like this doesn't work!
Yeah, I tried it too and it didn't work. As far as I remember jaxtyping hook actually doesn't do anything if beartype hook is used at the same time. So it doesn't add {push,pop}_shape_memo
and shape "variables" will be shared throughout different functions in call stack
# but this does not raise an error!
As far as I remember it's also because of lack of calls to memo-manipulation routines. You need to call 'push' at least once for shape variable checks to work.
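The role of the push/pop calls discussed in this thread can be illustrated with a toy model. This is not jaxtyping's code: the names with_shape_context and check_shape are hypothetical, shapes are plain tuples, and the behaviour with no active context (every check passes) is a simplifying assumption chosen to mirror the symptom reported above.

```python
import functools
import threading

_storage = threading.local()


def _stack() -> list:
    """Per-thread stack of {dimension name: size} memos."""
    if not hasattr(_storage, "stack"):
        _storage.stack = []
    return _storage.stack


def with_shape_context(fn):
    """Toy analogue of a decorator that pushes/pops a fresh shape memo."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _stack().append({})      # "push"
        try:
            return fn(*args, **kwargs)
        finally:
            _stack().pop()       # "pop"

    return wrapper


def check_shape(shape: tuple, names: str) -> bool:
    """Toy isinstance-style check binding dimension names to sizes."""
    dims = names.split()
    if len(dims) != len(shape):
        return False
    stack = _stack()
    if not stack:
        # Assumption of this sketch: with no memo pushed there is nothing
        # to compare against, so named dimensions cannot conflict.
        return True
    memo = stack[-1]
    for name, size in zip(dims, shape):
        # The first use binds the name; later uses must agree.
        if memo.setdefault(name, size) != size:
            return False
    return True


@with_shape_context
def foo():
    first = check_shape((3,), "size")
    second = check_shape((4,), "size")   # conflicts with the binding above
    return first, second
```

Inside the decorated function the second check fails because "size" is already bound to 3; the same two checks run outside any context would both pass in this toy model.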
Great project, helps me understand DL code a lot.
I used it like this:
but it turns out that it doesn't do a runtime type check on this
x: Mlp_mid = self.act(self.lin1(x))
line. And this makes me feel insecure. So, my question is, will this feature be added in the future? Or is it in conflict with some design intention?

BTW, I mainly use it when I am trying to understand others' code. But can I include it in production? How much does it slow down the training and inference?