patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

TorchScript compatibility? #13

Open Linux-cpp-lisp opened 3 years ago

Linux-cpp-lisp commented 3 years ago

Hi all,

This library looks very nice :)

Is TensorType compatible with the TorchScript compiler? As in, are the annotations transparently converted to torch.Tensor as far as torch.jit.script is concerned, allowing annotated modules/functions to be compiled? (I'm not worried about whether the type checking is applied in TorchScript, just whether an annotated program that gets shape-checked in Python can be compiled down to TorchScript.)

Thanks!

patrick-kidger commented 3 years ago

So I've been playing with this for a bit and unfortunately can't get it to work.

If you or someone else does manage to get this working, then I'd be happy to accept a PR on it.

For posterity:

Linux-cpp-lisp commented 3 years ago

Hi @patrick-kidger, thanks for the quick answer! This level of arcane tinkering with TorchScript definitely sounds familiar to me... :grin:

The issue you link in the third bullet does make it look like there is nothing that can be done here until PyTorch resolves the underlying incompatibility with Python. (If I'm understanding this right, you couldn't even do Annotated[torch.Tensor, something_else], since it wouldn't be parsable as a string, even though the Python people worked hard to make Annotated backwards compatible.) Hopefully the PyTorch people are going to start using Python introspection for this like they said in the linked issue.

EDIT: it looks like fixes to this may have been merged? unclear: https://github.com/pytorch/pytorch/pull/29623
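
A side note on the Annotated point above: at runtime, Annotated metadata is introspectable with plain stdlib tooling, which is what makes it backwards compatible in ordinary Python. A minimal stdlib sketch (independent of torch; `int` stands in for torch.Tensor):

```python
from typing import Annotated, get_type_hints

# Annotated[int, "batch"] behaves exactly like int to ordinary type checkers;
# the extra metadata is only visible via introspection.
def f(x: Annotated[int, "batch"]) -> int:
    return x

hints = get_type_hints(f, include_extras=True)
print(hints["x"].__metadata__)  # ('batch',)
```

TorchScript's string-based annotation parser can't see any of this, which is the incompatibility being discussed.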

patrick-kidger commented 3 years ago

Haha!

To answer the question: I agree it's unclear whether or not that issue is fixed. Either way, because of that or some other issue, our end use case doesn't seem to be working at the moment.

kharitonov-ivan commented 3 years ago

Hi! Are there any updates on this, guys?

patrick-kidger commented 3 years ago

Not that I know of. As far as I know this is still a limitation in torchscript itself.

If this is a priority for you then you might like to try bringing this up with the torchscript team. They might know more about any possibilities for making this work.

martenlienen commented 2 years ago

I have found a workaround. Let's say you have the following function

def f(x: TensorType["batch", "feature"]):
    return x.sum()

which you want to use in TorchScript. TorchScript does not like generic types in signatures, but we want to keep the dimension annotations somewhere for documentation purposes. We can work around this with a subclass.

import torch
from torchtyping import TensorType

class BatchedFeatureTensor(TensorType["batch", "feature"]):
    pass

@torch.jit.script
def f(x: BatchedFeatureTensor):
    return x.sum()

print(f(torch.tensor([[-1.0, 2.0, 1.2]])))
print(f.code)

# => tensor(2.2000)
# => def f(x: Tensor) -> Tensor:
# =>   return torch.sum(x)
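
A note on why this works (my reading, not verified against TorchScript internals): subclassing a parameterized generic produces an ordinary runtime class, so the annotation in the signature is a plain name rather than a subscripted expression the compiler would have to parse. A stdlib analogue, with typing.List standing in for TensorType:

```python
from typing import List

# Subclassing a parameterized generic alias yields a normal runtime class;
# the type parameter is erased at runtime, leaving a plain list subclass.
class IntList(List[int]):
    pass

xs = IntList([1, 2, 3])
print(isinstance(xs, list), sum(xs))  # True 6
```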

Datasciensyash commented 2 years ago

Found another way to deal with torchscript. Just paste the code and call patch_torchscript() before exporting.

import re
import typing as tp

import torch

ttp_regexp = re.compile(r"TensorType\[[^\]]*\]")
torchtyping_replacer = "torch.Tensor"

def _replace_torchtyping(source_lines: tp.List[str]) -> tp.List[str]:

    # Join all lines
    cat_lines = "".join(source_lines)

    # Quick exit, if torchtyping is not used
    if ttp_regexp.search(cat_lines) is None:
        return source_lines

    # Replace TensorType
    cat_lines = ttp_regexp.sub(torchtyping_replacer, cat_lines)

    # Split back into lines, preserving line endings
    # (splitlines avoids the spurious trailing blank line that
    # split("\n") plus re-appending "\n" would introduce)
    source_lines = cat_lines.splitlines(keepends=True)

    return source_lines

def _torchtyping_destruct_wrapper(func: tp.Callable) -> tp.Callable:
    def _wrap_func(obj: tp.Any, error_msg: tp.Optional[str] = None) -> tp.Tuple[tp.List[str], int, tp.Optional[str]]:
        srclines, file_lineno, filename = func(obj, error_msg)
        srclines = _replace_torchtyping(srclines)
        return srclines, file_lineno, filename

    return _wrap_func

def patch_torchscript() -> None:
    """
    Patch torchscript to work with torchtyping.

    Returns: None.

    """
    # Patch _sources if torch >= 1.10.0, else torch.jit.frontend
    if hasattr(torch, "_sources"):
        src = getattr(torch, "_sources")  # noqa: B009
    else:
        src = getattr(torch.jit, "frontend")  # noqa: B009

    src.get_source_lines_and_file = _torchtyping_destruct_wrapper(src.get_source_lines_and_file)
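
To illustrate what the regex rewrite does, here is a standalone check of just the substitution step (no torch required), using the same pattern as above on a hypothetical source line:

```python
import re

ttp_regexp = re.compile(r"TensorType\[[^\]]*\]")

# Every TensorType[...] annotation in the source text is rewritten to
# torch.Tensor before TorchScript parses it.
src = 'def f(x: TensorType["batch", "feature"]) -> TensorType["batch"]:'
patched = ttp_regexp.sub("torch.Tensor", src)
print(patched)  # def f(x: torch.Tensor) -> torch.Tensor:
```

Note the pattern assumes no `]` appears inside the annotation itself, which holds for typical torchtyping usage.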