patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Incompatibility with latest typeguard version #80

Open dfuchsgruber opened 1 year ago

dfuchsgruber commented 1 year ago

I updated typeguard to the latest release on GitHub: https://github.com/agronholm/typeguard (4.0.0rc5), and now get errors when shape-annotating torch Tensors (the code worked fine with previous typeguard versions). In particular, running the following simple snippet:

from typeguard import typechecked
from jaxtyping import jaxtyped, Float
from torch import Tensor

@jaxtyped
@typechecked
def foo(x: Float[Tensor, 'a b c']):
    return 1

Will throw this error:

  File "/nfs/homedirs/fuchsgru/graph-active-learning/foo.py", line 6, in <module>
    @typechecked
     ^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_decorators.py", line 213, in typechecked
    retval = instrument(target)
             ^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_decorators.py", line 54, in instrument
    instrumentor.visit(module_ast)
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 561, in visit_Module
    self.generic_visit(node)
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 494, in generic_visit
    value = self.visit(value)
            ^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 672, in visit_FunctionDef
    annotation = self._convert_annotation(deepcopy(arg.annotation))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 546, in _convert_annotation
    new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 339, in visit
    new_node = super().visit(node)
               ^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 391, in visit_Subscript
    items = [self.visit(item) for item in slice_value.elts]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 391, in <listcomp>
    items = [self.visit(item) for item in slice_value.elts]
             ^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 339, in visit
    new_node = super().visit(node)
               ^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 418, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/site-packages/typeguard/_transformer.py", line 442, in visit_Constant
    expression = ast.parse(node.value, mode="eval")
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/staff-ssd/fuchsgru/miniconda3/envs/graph_active_learning/lib/python3.11/ast.py", line 50, in parse
    return compile(source, filename, mode, flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<unknown>", line 1
    a b c
      ^
SyntaxError: invalid syntax

It seems typeguard is parsing the string annotation "a b c" with ast as if it were a Python expression, which fails because it is not valid syntax. Is there any workaround for this? I need the latest typeguard version, as earlier versions do not seem to support annotating @property-decorated functions.
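The failure at the bottom of the traceback can be reproduced with the standard library alone, which is consistent with the string being handed to `ast.parse(..., mode="eval")` (presumably because string constants inside annotations are treated as forward references). A minimal sketch; the helper name `parses_as_expression` is made up for illustration and is not part of typeguard or jaxtyping:

```python
import ast


def parses_as_expression(text: str) -> bool:
    """Return True if `text` is a syntactically valid Python expression."""
    try:
        ast.parse(text, mode="eval")
    except SyntaxError:
        return False
    return True


# A jaxtyping-style shape string is not a valid Python expression...
shape_ok = parses_as_expression("a b c")
# ...whereas a typical forward reference, such as a class name, is.
name_ok = parses_as_expression("Tensor")
print(shape_ok, name_ok)
```

Any annotation string containing space-separated names would fail the same way, so the problem is not specific to torch Tensors.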

Relevant versions of typeguard and jaxtyping:

typeguard                4.0.0rc5.post1
jaxtyping                0.2.15

patrick-kidger commented 1 year ago

This looks to be a typeguard bug. Here's a repro without using jaxtyping at all:

from typing_extensions import Annotated
from typeguard import typechecked

@typechecked
def foo(x: Annotated[int, 'a b c']):
    return 1

I'd suggest raising this as an issue on the typeguard issue tracker!

dfuchsgruber commented 1 year ago

I wasn't sure if typeguard is supposed to be able to handle "arbitrary" annotations, as it seems to only support things that are parsable by ast. Anyway, I'm gonna open a typeguard issue to get a response from the devs there, thanks :)

Conchylicultor commented 1 year ago

(oops, wrong issue)

patrick-kidger commented 1 year ago

FWIW I think this is now fixed with more recent typeguard releases, if anyone else would like to confirm.

johnryan465 commented 1 year ago

@patrick-kidger the issue has not been fixed, although it has been closed on the typeguard repo (https://github.com/agronholm/typeguard/issues/353 for anyone else running into this). My workaround for the moment has been to downgrade to a pre-4.0 version of typeguard.
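For anyone taking the downgrade route (for example by pinning `typeguard<4` in their requirements), a small guard like the one below could warn when an incompatible version is installed. This is only a sketch: `is_pre_4` is a hypothetical helper, and it assumes a plain version string whose first dot-separated component is the major version (no epoch prefix).

```python
from importlib.metadata import PackageNotFoundError, version


def is_pre_4(version_string: str) -> bool:
    """Return True if the major component of a version string is below 4."""
    # Assumption: the string starts with an integer major version, e.g. "2.13.3".
    major = version_string.split(".", 1)[0]
    return int(major) < 4


try:
    installed = version("typeguard")
except PackageNotFoundError:
    # typeguard is not installed in this environment; nothing to check.
    installed = None

if installed is not None and not is_pre_4(installed):
    print(f"typeguard {installed} may hit the annotation parsing error above")
```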