patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Can we check statement-level annotations? #7

Open awf opened 2 years ago

awf commented 2 years ago

One of my dreams for this package was to turn code like this

  query = linear(head.query, t1)                  # L x Dk
  key = linear(head.key, t1)                      # L x Dk

into this

  query : LxDk = linear(head.query, t1)
  key   : LxDk = linear(head.key, t1)

where we have written

  LxDk = f32["L Dk"] 

earlier in the @jaxtyped function.

But it looks as if these annotations aren't checked? I haven't looked into how hard that might be - is it a lot of work?

patrick-kidger commented 2 years ago

So checking intermediate annotations like this is really the job of a runtime type checker (typeguard/beartype), just as checking argument/return annotations already is. In either case, all jaxtyping does is provide isinstance-compatible array types. After all, you may equally well wish to check x: int = foo(), which has nothing to do with JAX types.
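A quick stdlib check of that division of labor (the function names are illustrative only): CPython attaches no runtime meaning to a local variable annotation and, inside a function body, does not even evaluate it. So nothing complains below. The explicit `isinstance` test in `checked` is the sort of statement a runtime checker would need to insert.

```python
def untyped_looking(n):
    # CPython never evaluates a local annotation, so the undefined name
    # `NoSuchType` is harmless and the assigned value is never checked.
    x: NoSuchType = n * 2  # noqa: F821
    return x


def checked(n):
    x: int = n * 2
    # The explicit check a runtime type checker would have to insert.
    if not isinstance(x, int):
        raise TypeError(f"x should be int, got {type(x).__name__}")
    return x


print(untyped_looking("ab"))  # prints abab: the annotation is silently ignored
print(checked(3))             # prints 6; checked(1.5) raises TypeError
```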

What's probably needed here is a decorator that parses the abstract syntax tree for a function, detects annotations, and then inserts manual isinstance checks such as assert isinstance(x, int). This probably qualifies as "not too bad" for someone already familiar with AST rewriting; something like:

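import ast
import inspect
import textwrap
import beartype
import jaxtyping

def check_intermediate_annotations(fn):
  # Re-parse the function's source (fails in a REPL, where no source file exists).
  tree = ast.parse(textwrap.dedent(inspect.getsource(fn)))
  # Strip decorators so re-running the `def` below doesn't re-apply them.
  tree.body[0].decorator_list = []
  # rewrite `tree` here using ast.NodeVisitor and ast.NodeTransformer,
  # e.g. inserting `assert isinstance(x, LxDk)` after `x: LxDk = ...`
  ast.fix_missing_locations(tree)
  # A `def` is a statement, so exec (not eval) it in the function's globals.
  namespace = {}
  exec(compile(tree, fn.__code__.co_filename, "exec"), fn.__globals__, namespace)
  return namespace[fn.__name__]

@jaxtyping.jaxtyped
@beartype.beartype
@check_intermediate_annotations
def bar(y):
  x: LxDk = foo()  # LxDk and foo are assumed to be defined elsewhere
  return x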

Although I note that this kind of source-file parsing and unparsing is a bit fraught with edge cases (e.g. it won't work in a REPL). Unfortunately Python just doesn't provide a good way to handle this kind of thing.
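Below is a minimal, stdlib-only sketch of the rewrite step that the decorator above leaves as a comment. It is an illustration under stated assumptions, not jaxtyping's or beartype's code: it only handles `name: hint = value` assignments, raises `TypeError` instead of using `assert` (which `python -O` strips), and takes the source as a string so it also works where `inspect.getsource` cannot, such as a REPL. `InsertIsinstanceChecks` and `compile_with_checks` are hypothetical names.

```python
import ast
import copy


class InsertIsinstanceChecks(ast.NodeTransformer):
    """After each `name: hint = value`, add `if not isinstance(name, hint): raise TypeError`."""

    def visit_AnnAssign(self, node: ast.AnnAssign):
        # Skip bare declarations (`x: int`) and non-name targets (`self.x: int = ...`).
        if node.value is None or not isinstance(node.target, ast.Name):
            return node
        name = node.target.id
        message = f"{name} is not an instance of {ast.unparse(node.annotation)}"
        check = ast.If(
            test=ast.UnaryOp(
                op=ast.Not(),
                operand=ast.Call(
                    func=ast.Name(id="isinstance", ctx=ast.Load()),
                    args=[ast.Name(id=name, ctx=ast.Load()), copy.deepcopy(node.annotation)],
                    keywords=[],
                ),
            ),
            body=[
                ast.Raise(
                    exc=ast.Call(
                        func=ast.Name(id="TypeError", ctx=ast.Load()),
                        args=[ast.Constant(value=message)],
                        keywords=[],
                    ),
                    cause=None,
                )
            ],
            orelse=[],
        )
        # Returning a list splices both statements into the enclosing body.
        return [node, ast.copy_location(check, node)]


def compile_with_checks(source: str, name: str, namespace: dict):
    """Parse `source`, insert the checks, exec it in `namespace`, return `namespace[name]`."""
    tree = InsertIsinstanceChecks().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    exec(compile(tree, filename="<checked>", mode="exec"), namespace)
    return namespace[name]


SOURCE = """
def scale(n):
    x: int = n * 2
    return x
"""

scale = compile_with_checks(SOURCE, "scale", {})
print(scale(3))  # 6; scale(1.5) raises TypeError: x is not an instance of int
```

The hint is handed straight to `isinstance`, so it must be a class, a tuple of classes, or an isinstance-compatible type (such as jaxtyping's array types); a subscripted hint like `List[int]` would make the inserted check itself raise `TypeError`.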

Off the top of my head I don't know of a runtime type checker that does this, though. (CC @leycec for interest.)

leycec commented 2 years ago

Fascinating feature request intensifies. Thanks so much for pinging me into the fray, @patrick-kidger. Exactly as you suspect, no actively maintained runtime type-checker that I know of currently performs static type-checking at runtime. Sad cat emoji is sad. :crying_cat_face:

That said, @beartype does have an open feature request encouraging us to eventually do this. This is fun stuff, because it's hard stuff. Actually, it's not too hard to naively perform static type-checking at runtime by combining the fearsome powers of import hooks + abstract syntax tree (AST) inspection. But it's really hard to do so without destroying runtime performance – especially in pure Python. Extremely aggressive on-disk caching (e.g., the sordid pile of JSON files that mypy dumps into project-specific .mypy_cache/ subdirectories) would certainly be a hard prerequisite.
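To make "import hooks + AST inspection" concrete, here is a toy, stdlib-only sketch; every name in it is hypothetical, and it is not how beartype (or anyone else) implements this. A `MetaPathFinder` claims the modules of one package, and its loader parses each module, inserts an `isinstance` check after every `name: hint = value`, and executes the result. It deliberately skips bytecode caching, so every import pays the full parse-and-rewrite cost, which is exactly the overhead such caching would have to hide. Hints must be classes that exist at runtime; string or `TYPE_CHECKING`-only hints would break it.

```python
import ast
import importlib.abc
import importlib.machinery
import sys


class AddIsinstanceChecks(ast.NodeTransformer):
    """Follow every `name: hint = value` with an isinstance() check."""

    def visit_AnnAssign(self, node):
        if node.value is None or not isinstance(node.target, ast.Name):
            return node
        name, hint = node.target.id, ast.unparse(node.annotation)
        message = f"{name} is not an instance of {hint}"
        check = ast.parse(
            f"if not isinstance({name}, {hint}):\n    raise TypeError({message!r})"
        ).body[0]
        return [node, check]


class CheckingLoader(importlib.abc.Loader):
    """Loader that rewrites a module's AST before executing it (no bytecode cache)."""

    def __init__(self, path):
        self.path = path

    def create_module(self, spec):
        return None  # fall back to Python's default module creation

    def exec_module(self, module):
        with open(self.path, encoding="utf-8") as source_file:
            tree = ast.parse(source_file.read(), filename=self.path)
        tree = ast.fix_missing_locations(AddIsinstanceChecks().visit(tree))
        exec(compile(tree, self.path, "exec"), module.__dict__)


class CheckingFinder(importlib.abc.MetaPathFinder):
    """Claim pure-Python modules of one package and hand them to CheckingLoader."""

    def __init__(self, package):
        self.package = package

    def find_spec(self, fullname, path, target=None):
        if fullname != self.package and not fullname.startswith(self.package + "."):
            return None  # not ours: let the normal import machinery handle it
        spec = importlib.machinery.PathFinder.find_spec(fullname, path)
        if spec is None or not isinstance(spec.loader, importlib.machinery.SourceFileLoader):
            return None
        spec.loader = CheckingLoader(spec.origin)
        return spec


def install_checking_hook(package):
    """Register the hook ahead of the default finders; return it for later removal."""
    finder = CheckingFinder(package)
    sys.meta_path.insert(0, finder)
    return finder
```

Calling `install_checking_hook("mypkg")` before `import mypkg` would then check annotated assignments throughout the (hypothetical) package `mypkg`.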

I Don't Like What I'm Hearing

Until then, @beartype provides a reasonably well-documented procedural API for type-checking arbitrary things against arbitrary type hints at any time:

from beartype.abby import die_if_unbearable  # beartype.door.die_if_unbearable in newer releases

query : LxDk = linear(head.query, t1)
key : LxDk = linear(head.key, t1)

# Runtime type-check everything above.
die_if_unbearable(query, LxDk)
die_if_unbearable(key, LxDk)

If that's a bit too much egregious boilerplate, consider wrapping the above calls to linear() with a @beartype-friendly factory: e.g.,

from beartype.abby import die_if_unbearable  # beartype.door.die_if_unbearable in newer releases

def linear_typed(*args, hint: object, **kwargs):
    '''
    Linear JAX array runtime type-checked by the passed type hint.
    '''

    linear_array = linear(*args, **kwargs)
    die_if_unbearable(linear_array, hint)
    return linear_array

# hint is keyword-only, so it can't be swallowed by *args.
query : LxDk = linear_typed(head.query, t1, hint=LxDk)
key : LxDk = linear_typed(head.key, t1, hint=LxDk)

Technically, that still violates DRY a bit by duplicating the LxDk type hints. Pragmatically, that's probably as concise as we can manage... for the moment.
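One way to write the hint once per wrapped function rather than at every call site is to bind it when wrapping. A stdlib-only sketch, with a hypothetical `with_return_check` that uses plain `isinstance` in place of `die_if_unbearable` (which, unlike `isinstance`, also accepts PEP-compliant hints such as `List[int]`):

```python
import functools


def with_return_check(fn, hint):
    """Return a wrapper around `fn` that checks every result against `hint`.

    `hint` is anything isinstance() accepts: a class, a tuple of classes,
    or an isinstance-compatible array type such as jaxtyping provides.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        if not isinstance(result, hint):
            raise TypeError(f"{fn.__name__} returned {type(result).__name__}, expected {hint!r}")
        return result

    return wrapper


# Stand-in for `linear`; in the thread's setting this would be something
# like `linear_LxDk = with_return_check(linear, LxDk)` (hypothetical).
def double(n):
    return n * 2


checked_double = with_return_check(double, int)
```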

Let's pray I actually do something and make static runtime type-checking happen, everybody. :face_exhaling:

awf commented 2 years ago

Thanks @leycec! Just to clarify on "destroying runtime performance", do you mean making it worse than

query : LxDk = linear(head.query, t1)
die_if_unbearable(query, LxDk)
key : LxDk = linear(head.key, t1)
die_if_unbearable(key, LxDk)

?

awf commented 2 years ago

Because the above seems acceptable to me, particularly under a jax-style define-by-run scheme.

And if it's not I might wrap die_if_unbearable to have logic like:

if rand() > 0.93: die_if_unbearable(query, LxDk)

or

if (beartype_time / total_program_time < bearable_beartype_overhead or
        rand() > beartype_time / total_program_time / bearable_beartype_overhead):
    die_if_unbearable(query, LxDk)

[I realise there are simplifications/corrections of the last logic, hope the sentiment is clear]
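A runnable take on the budgeted idea above, with one possible correction to its logic (as written, the second clause can never fire once the budget is exceeded, since the ratio is then above 1). In this sketch, all names are hypothetical: every check runs while measured checking overhead stays within `budget`, and once it is exceeded each check runs with probability `budget / overhead`. "Total program time" is approximated as time since the checker was created, which is an assumption. Skipped checks can, of course, let a badly typed value through.

```python
import random
import time


class SampledChecker:
    """Hypothetical wrapper that thins out type checks once they get too costly.

    While time spent inside checks stays under `budget` (a fraction of the
    time elapsed since this checker was created), every check runs. Above
    budget, each check runs with probability budget / overhead, so checks
    are thinned in proportion to how far over budget we are.
    """

    def __init__(self, check, budget=0.05, rng=random.random):
        self.check = check  # e.g. die_if_unbearable, called as check(value, hint)
        self.budget = budget
        self.rng = rng
        self.check_time = 0.0
        self.start = time.perf_counter()

    def __call__(self, value, hint):
        elapsed = time.perf_counter() - self.start
        overhead = self.check_time / elapsed if elapsed > 0 else 0.0
        if overhead <= self.budget or self.rng() < self.budget / overhead:
            began = time.perf_counter()
            try:
                self.check(value, hint)
            finally:
                self.check_time += time.perf_counter() - began
```

Usage would look like `check_sometimes = SampledChecker(die_if_unbearable)` followed by `check_sometimes(query, LxDk)`.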

patrick-kidger commented 2 years ago

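Right, so JAX sits in a lovely spot for applicability of runtime type checking, because Python is only ever being used as a metaprogramming language for XLA: under jax.jit, the Python code (checks included) runs only while the function is being traced, not on every call.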

In this context I wouldn't worry about the extra runtime overhead.

leycec commented 2 years ago

Agh! I should be more explicit in my jargon, especially when slinging around suspicious phrases like "destroying runtime performance." So...

Just to clarify on "destroying runtime performance", do you mean making it worse than

query : LxDk = linear(head.query, t1)
die_if_unbearable(query, LxDk)
key : LxDk = linear(head.key, t1)
die_if_unbearable(key, LxDk)

Yes. Much, much worse than that. die_if_unbearable() does not destroy runtime performance, because all @beartype operations to date – including die_if_unbearable() and @beartype-decorated wrapper functions alike – exhibit constant-time O(1) runtime performance with negligible constant factors. It don't get faster than that, right? Literally.
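As a rough illustration of how a container check can stay O(1): beartype's documentation describes checking one randomly selected item per container rather than every item. The stdlib sketch below applies that idea to a `list[...]` hint; `is_list_of` is a hypothetical name and this is not beartype's generated code.

```python
import random


def is_list_of(obj, item_type, rng=random):
    """Constant-time sketch of checking `obj` against a `list[item_type]` hint.

    Only one randomly chosen item is inspected, so the cost is the same for
    a 3-item list and a 3-million-item list. The trade-off: a single call
    can miss a badly typed item, though repeated calls on the same data
    become increasingly likely to sample it.
    """
    if not isinstance(obj, list):
        return False
    if not obj:
        return True  # an empty list satisfies list[anything]
    return isinstance(rng.choice(obj), item_type)
```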

By "destroy runtime performance," I was instead referring to the hypothetical crippling performance burden of doing static type-checking analysis at runtime via import hooks and AST inspection. @beartype don't do that yet; nobody does. The practical difficulties of optimizing static analysis at runtime are a Big Reason™ why.

But... someday @beartype or somebody else will go there. A runtime type-checker that efficiently performs static analysis at runtime would effectively obsolete standard static type-checkers (e.g., mypy, pyright) for most practical purposes. Since there is both Big Money™ and Big Hype™ for building that field of dreams, it will happen... someday.

Until then, we collectively wish upon a rainbow. :rainbow:

Because the above seems acceptable to me, particularly under a jax-style define-by-run scheme.

Absolutely. @beartype has been profiled to be disgustingly fast. That's the whole point, really. @beartype is actually two orders of magnitude faster than even pydantic, which is compiled down to C via Cython. Yeah. We're that fast.

You should never need to conditionally disable @beartype. If you do, bang on our issue tracker and we'll promptly resolve the performance regressions you are seeing. Until then, the best way to use @beartype is to just always use @beartype.

And if it's not I might wrap die_if_unbearable to have logic like:

if rand() > 0.93: die_if_unbearable(query, LxDk)

...heh. Probabilistic runtime type-checking. Love it! I must acknowledge cleverness when I see it. Admittedly, that also makes my eye twitch spasmodically.

If you do end up profiling die_if_unbearable() for your particular use case, please post your timings. That function is a bit less optimized than it could be – mostly because I didn't realize there was actual demand for statement-level runtime type-checking.

Now I know. And knowledge is half the battle.

In this context I wouldn't worry about the extra runtime overhead.

These are sweet, soothing words. Please say more relieving things like this. :relieved:

awf commented 2 years ago

Thanks both for this discussion - I implemented something quick based on @patrick-kidger's spike above, and it seems to work quite nicely. Next step is to integrate with jaxtyping, but I thought I would put it out here...

https://github.com/awf/awfutils#typecheck

A fairly direct copy of your suggestion above...

https://github.com/awf/awfutils/blob/7359acb6528325f6770fc9c28aab86f548d22ad4/typecheck.py#L133

leycec commented 2 years ago

Extremely impressive. @awfutils.typecheck is the first practical attempt I've seen at performing static type-checking at runtime. Take my thunderous clapping! :clap: :clap: :clap:

Your current approach is outrageously useful, but it appears to support only isinstance()-able classes rather than PEP-compliant type hints: e.g.,

# I suspect this fails hard, but am lazy and thus did not test.
from typing import List

@typecheck
def foo(x : List[int], y : int):
  z : List[int] = x * y
  w : float = z[0] * 3.2
  return w

foo([3, 2, 1], 1.3)

Is that right? If so, that's still impressive tech for a several hundred-line decorator. I'll open up a feature request on your issue tracker to see if we can't trivially generalize that to support all (...or at least most) PEP-compliant type hints, @awf.
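For context on why plain `isinstance` isn't enough here: `isinstance([3, 2, 1], List[int])` raises `TypeError`, because subscripted generics can't be used in instance checks. A checker has to take the hint apart with `typing.get_origin` / `typing.get_args` and recurse. The stdlib sketch below (the name `matches` is hypothetical, and it covers only a handful of hint forms) shows the shape of that recursion; real checkers such as beartype handle far more.

```python
import types
from typing import Any, List, Literal, Optional, Union, get_args, get_origin

# `int | str` has origin types.UnionType on Python 3.10+; typing.Union otherwise.
UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


def matches(value, hint) -> bool:
    """Hypothetical mini-checker for a few PEP 484/585 hint forms.

    Every item is visited, so the cost is linear in the container size.
    """
    if hint is Any:
        return True
    origin, args = get_origin(hint), get_args(hint)
    if origin is None:
        return isinstance(value, hint)  # a plain class such as int or float
    if origin in UNION_ORIGINS:
        return any(matches(value, arg) for arg in args)
    if origin is Literal:
        return value in args
    if not isinstance(value, origin):
        return False  # wrong container class, e.g. a tuple where a list was hinted
    if origin is list and args:
        return all(matches(item, args[0]) for item in value)
    if origin is dict and args:
        key_hint, value_hint = args
        return all(matches(k, key_hint) and matches(v, value_hint) for k, v in value.items())
    if origin is tuple and args:
        if len(args) == 2 and args[1] is Ellipsis:
            return all(matches(item, args[0]) for item in value)
        return len(value) == len(args) and all(map(matches, value, args))
    return True  # other generics: only the outer class is checked


print(matches([3, 2, 1], List[int]))  # True
print(matches(1.3, int))              # False: the `y : int` violation above
print(matches(None, Optional[int]))   # True
```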

In short, this is so good. \o/

patrick-kidger commented 1 year ago

Not really documented yet, but for anyone coming across this issue: this now exists in beartype (https://github.com/beartype/beartype/issues/7#issuecomment-1646494470)!

leycec commented 1 year ago

Indeed. As @patrick-kidger notes, our new beartype.claw API transforms @beartype into a hybrid runtime-static type-checker. This is the way:

# In your top-level "{your_package}.__init__" submodule:
from beartype.claw import beartype_this_package
beartype_this_package()

That's it. @beartype will now type-check statement-level annotations in concert with jaxtyping.

Not really documented yet...

...yeah. Noticed that, huh? I've intentionally left beartype.claw undocumented for a bit. Technically, it's rock solid as is and "good enough" for most use cases and production workloads. Still, first impressions are everything; it'll really benefit from stability improvements in our upcoming @beartype 0.16.0 release – especially with respect to complex forward references (e.g., 'Dict[str, MuhGeneric[int]]') and PEP 563 (i.e., from __future__ import annotations). Thankfully, I've almost finalized @beartype 0.16.0 and expect it to land in a week or two.

Until then, one-liners for great QA justice! :muscle: :bear: