patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Issues with torch.compile #196

Open botev opened 3 months ago

botev commented 3 months ago

We are very happy that jaxtyping supports PyTorch as well, but we are currently hitting a weird error/edge case and were hoping you could give some suggestions. When compiling a module and trying to run it, we get this stack trace:

  File "/build/work/cfc8a89b76634373e85beb2a59a94e9e781a/google3/runfiles/google3/third_party/py/torch/_dynamo/bytecode_transformation.py", line 646, in compute_exception_table
    keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1]))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.InternalTorchDynamoError: '<' not supported between instances of 'NoneType' and 'int'

from user code:
   File "/build/work/cfc8a89b76634373e85beb2a59a94e9e781a/google3/runfiles/google3/third_party/py/jaxtyping/_decorator.py", line 411, in wrapped_fn
    bound = param_signature.bind(*args, **kwargs)

patrick-kidger commented 3 months ago

I'd suggest raising this with the PyTorch folks, including a MWE. This is likely an instance of hitting something torch.compile doesn't support yet. Another example came up recently at https://github.com/pytorch/pytorch/issues/122093

botev commented 3 months ago

Hmm, I managed to fix a few things and rearrange, but now I get:

The problem arose whilst typechecking parameter 'self'.
Actual value: MyModuel(....)
Expected type: <class 'inspect._empty'>.

which I would guess is because torch.compile rewrites the forward pass as a pure function?
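
For context on the "Expected type" in that message: `inspect._empty` is the sentinel the standard library's `inspect` module uses for a parameter that has no annotation, which is the usual state of `self`. The snippet below is only a sketch of that stdlib behaviour; the class `MyModule` and its `forward` method are hypothetical stand-ins, and it says nothing about how jaxtyping or torch.compile treat `self` internally.

```python
import inspect


class MyModule:
    # Hypothetical stand-in for a module: `self` carries no annotation,
    # while `x` is annotated.
    def forward(self, x: int) -> int:
        return x


params = inspect.signature(MyModule.forward).parameters

# An unannotated parameter reports the sentinel inspect.Parameter.empty,
# whose repr is the "<class 'inspect._empty'>" seen in the error message.
self_annotation = params["self"].annotation
x_annotation = params["x"].annotation

print(self_annotation)
print(x_annotation)
```

So the message suggests that `self` was checked as if it were an ordinary, annotated argument, which would be consistent with the guess above but does not confirm it.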

patrick-kidger commented 3 months ago

I'm not sure of the details of torch.compile, so it's hard for me to speculate, I'm afraid.

Chrixtar commented 2 months ago

Hi @botev, I'm facing exactly the same (original) problem. Did you find a solution?

botev commented 2 months ago

Unfortunately no, I just disabled the guard for PyTorch.
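
One way such a workaround could look is to make the runtime check opt-in, so that functions destined for torch.compile are left undecorated. This is a minimal sketch under stated assumptions, not the code used in this thread: `maybe_typechecked` and `bind_checker` are hypothetical names, and `bind_checker` is only a stand-in for a real checking decorator (it mimics the `param_signature.bind(*args, **kwargs)` call from the traceback). With jaxtyping, the real decorator would be passed in its place.

```python
import functools
import inspect


def maybe_typechecked(checker, enabled):
    """Apply `checker` as a decorator only when `enabled` is true (hypothetical helper)."""
    def decorate(fn):
        # When disabled, return the function untouched so no wrapper is traced.
        return checker(fn) if enabled else fn
    return decorate


def bind_checker(fn):
    """Stand-in checker: only validates that the call matches the signature."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        sig.bind(*args, **kwargs)  # raises TypeError on a mismatched call
        return fn(*args, **kwargs)

    return wrapped


# Checking switched off: `forward` is the plain, unwrapped function.
@maybe_typechecked(bind_checker, enabled=False)
def forward(x, scale=2):
    return x * scale


# Checking switched on: the same function wrapped by the stand-in checker.
checked_forward = maybe_typechecked(bind_checker, enabled=True)(forward)

print(forward(3))
print(checked_forward(3))
```

The trade-off is that the compiled path loses shape and dtype checking entirely, so it may be worth keeping the check enabled in an uncompiled test run.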