Linux-cpp-lisp opened 3 years ago

Hi all,

This library looks very nice :)

Is `TensorType` compatible with the TorchScript compiler? As in, are the annotations transparently converted to `torch.Tensor` as far as `torch.jit.script` is concerned, allowing annotated modules/functions to be compiled? (I'm not worried about whether the type checking applies in TorchScript, just whether an annotated program that gets shape-checked in Python can be compiled down to TorchScript.)

Thanks!
So I've been playing with this for a bit and unfortunately can't get it to work. If you or someone else does manage to get this working, then I'd be happy to accept a PR on it.

For posterity:

- `TensorType` does not currently inherit from `torch.Tensor`. This means that `@torch.jit.script def f(x: TensorType)` results in `TensorType` itself trying to be compiled, which fails.
- I tried adding `@torch.jit.ignore` in various places. I think ignoring things only really works for free functions or methods of subclasses of `torch.nn.Module`.
- Modifying `TensorType` to inherit from `torch.Tensor` allows `@torch.jit.script def f(x: TensorType)`, but `@torch.jit.script def f(x: TensorType["b"])` still breaks, with the error message `Unknown type constructor TensorType` from the TorchScript compiler.
- I also tried a few other things, like `class TensorType(typing.List)`, without success. My impression is that the only parameterised types admitted as annotations are the standard built-in ones like `List`.
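For reference, a minimal sketch of the failure in the third bullet (this assumes a locally modified `TensorType` that inherits from `torch.Tensor`; the exact exception type and message may vary across PyTorch versions):

```python
import torch
from torchtyping import TensorType

try:
    # Scripting a function annotated with a parameterised TensorType:
    # the TorchScript compiler rejects the subscripted annotation.
    @torch.jit.script
    def f(x: TensorType["b"]):
        return x + 1
except Exception as e:
    print(e)  # expect something like: Unknown type constructor TensorType
```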
Hi @patrick-kidger, thanks for the quick answer! This level of arcane tinkering with TorchScript definitely sounds familiar to me... :grin:

The issue you link in the third bullet does make it look like there is nothing that can be done here until PyTorch resolves the underlying incompatibility with Python. (If I'm understanding this right, you couldn't even do `Annotated[torch.Tensor, something_else]`, since it wouldn't be parsable as a string, even though the Python people worked hard to make `Annotated` backwards compatible.) Hopefully the PyTorch people are going to start using Python inspection for this, as they said in the linked issue.

EDIT: it looks like fixes to this may have been merged? Unclear: https://github.com/pytorch/pytorch/pull/29623
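For concreteness, the `Annotated` spelling under discussion would look something like the sketch below; per the point above, TorchScript's string-based annotation parsing was the obstacle, so this is illustration only (`BatchFeature` is a name I've made up, not anything torchtyping exposes):

```python
import torch
from typing import Annotated  # Python >= 3.9; use typing_extensions before that

# PEP 593: the dimension names ride along as metadata on a plain
# torch.Tensor annotation, without changing the underlying type.
BatchFeature = Annotated[torch.Tensor, "batch", "feature"]

def f(x: BatchFeature) -> torch.Tensor:
    return x.sum()
```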
Haha!

To answer the question: I agree it's unclear whether or not that issue is fixed. Either way, because of that or some other issue, our end use case doesn't seem to be working at the moment.
Hi! Are there any updates on this, guys?
Not that I know of. As far as I know this is still a limitation in TorchScript itself.

If this is a priority for you, then you might like to bring it up with the TorchScript team. They might know more about any possibilities for making this work.
I have found a workaround. Let's say you have the following function, which you want to use in TorchScript:

```python
def f(x: TensorType["batch", "feature"]):
    return x.sum()
```

TorchScript does not like generic types in signatures, but we want to keep the dimension annotations somewhere for documentation purposes. We can work around this with a subclass:
```python
import torch
from torchtyping import TensorType

class BatchedFeatureTensor(TensorType["batch", "feature"]):
    pass

@torch.jit.script
def f(x: BatchedFeatureTensor):
    return x.sum()

print(f(torch.tensor([[-1.0, 2.0, 1.2]])))
print(f.code)
# => tensor(2.2000)
# => def f(x: Tensor) -> Tensor:
# =>   return torch.sum(x)
```
Found another way to deal with TorchScript. Just paste the code below and call `patch_torchscript()` before exporting.

```python
import re
import typing as tp

import torch

# Matches any subscripted annotation, e.g. TensorType["batch", "feature"].
ttp_regexp = re.compile(r"TensorType\[[^\]]*\]")
torchtyping_replacer = "torch.Tensor"

def _replace_torchtyping(source_lines: tp.List[str]) -> tp.List[str]:
    # Join all lines
    cat_lines = "".join(source_lines)
    # Quick exit, if torchtyping is not used
    if ttp_regexp.search(cat_lines) is None:
        return source_lines
    # Replace TensorType annotations with plain torch.Tensor
    cat_lines = ttp_regexp.sub(torchtyping_replacer, cat_lines)
    # Split back into lines
    source_lines = cat_lines.split("\n")
    source_lines = [f"{i}\n" for i in source_lines]
    return source_lines

def _torchtyping_destruct_wrapper(func: tp.Callable) -> tp.Callable:
    def _wrap_func(obj: tp.Any, error_msg: tp.Optional[str] = None) -> tp.Tuple[tp.List[str], int, tp.Optional[str]]:
        srclines, file_lineno, filename = func(obj, error_msg)
        srclines = _replace_torchtyping(srclines)
        return srclines, file_lineno, filename
    return _wrap_func

def patch_torchscript() -> None:
    """
    Patch TorchScript's source retrieval so it never sees TensorType annotations.

    Returns: None.
    """
    # Patch torch._sources if torch >= 1.10.0, else torch.jit.frontend
    if hasattr(torch, "_sources"):
        src = getattr(torch, "_sources")  # noqa: B009
    else:
        src = getattr(torch.jit, "frontend")  # noqa: B009
    src.get_source_lines_and_file = _torchtyping_destruct_wrapper(src.get_source_lines_and_file)
```
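A minimal usage sketch, assuming the code above has been pasted into the same module (whether the rewrite takes effect may depend on the PyTorch version):

```python
import torch
from torchtyping import TensorType

patch_torchscript()  # must run before torch.jit.script reads any source

@torch.jit.script
def f(x: TensorType["batch", "feature"]):
    return x.sum()

# The patched source loader rewrites the annotation to torch.Tensor before
# TorchScript parses it, so the compiled signature shows a plain Tensor.
print(f.code)
```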