Open botev opened 3 months ago
I'd suggest raising this with the PyTorch folks, including a MWE. This is likely an instance of hitting something torch.compile
doesn't support yet. Another example came up recently at https://github.com/pytorch/pytorch/issues/122093
Hmm, I managed to fix a few things and rearrange, but now I get:
The problem arose whilst typechecking parameter 'self'.
Actual value: MyModuel(....)
Expected type: <class 'inspect._empty'>.
which I would guess is because torch.compile rewrites the forward pass as a pure function?
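For reference, `inspect._empty` is the sentinel that `inspect.signature` reports for a parameter with no annotation, so the message reads as if the checker looked up the annotation of `self`, found none, and treated the sentinel as the expected type. A minimal sketch using only the standard library (the class below is a hypothetical stand-in, not the module from this issue, and this does not show what torch.compile itself does):

```python
import inspect


class MyModule:  # hypothetical stand-in for the real module
    def forward(self, x: int) -> int:
        return x


params = inspect.signature(MyModule.forward).parameters

# `self` carries no annotation, so inspect reports its "empty" sentinel.
self_annotation = params["self"].annotation
# `x` is annotated, so the annotation is the actual type.
x_annotation = params["x"].annotation

print(self_annotation)  # the same sentinel class named in the error message
print(x_annotation)
```

A checker that does not special-case this sentinel would end up comparing the actual `self` value against it, which would be consistent with the error above.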
I'm not sure of the details of torch.compile, so it's hard for me to speculate, I'm afraid.
Hi @botev , I face exactly the same (original) problem. Did you find any solution?
Unfortunately no, I just disabled the guard for PyTorch.
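In case it helps anyone landing here, one way to switch a runtime check off for selected code paths is to apply the checking decorator conditionally. This is only a sketch under assumptions: the thread does not say how the guard was actually disabled, `maybe_typecheck` and `demo_checker` are hypothetical names, and `demo_checker` stands in for whatever checking decorator you use (for example a jaxtyping/beartype combination).

```python
import functools


def maybe_typecheck(checker, enabled):
    """Return a decorator that applies `checker` only when `enabled` is true."""
    def decorate(fn):
        # When disabled, hand back the original function untouched so that
        # nothing wraps it (and nothing extra gets traced by a compiler).
        return checker(fn) if enabled else fn
    return decorate


def demo_checker(fn):
    """Hypothetical stand-in for a real runtime type checker."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1  # a real checker would validate args here
        return fn(*args, **kwargs)
    wrapper.calls = 0
    return wrapper


@maybe_typecheck(demo_checker, enabled=False)
def forward_unchecked(x):
    return x * 2


@maybe_typecheck(demo_checker, enabled=True)
def forward_checked(x):
    return x * 2
```

The `enabled` flag could be driven by a project-level setting, so the checks stay on for eager-mode code and are skipped only for the compiled PyTorch modules.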
We are very happy that jaxtyping supports PyTorch as well, but we are currently hitting a weird error/edge case and were hoping you could give some suggestions. When compiling a module and trying to run it, we get this stack trace: