patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Type checking based on names #36

Closed tombosc closed 2 years ago

tombosc commented 2 years ago

Hello

I'd like to know if there's an easy way to check tensors by name:

import torch
from torch import rand
from torchtyping import TensorType, patch_typeguard, is_named
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

def test():
    t = Test()
    b = t.return_batch()
    o = t.return_other()
    v = t.func(b, o)
    u = t.func(o, b)  # can we have it raise TypeError?

class Test:
    def __init__(self):
        pass

    @typechecked
    def func(
        self,
        x: TensorType["batch"],
        y: TensorType["other"],
    ) -> TensorType["batch", "other"]:
        return torch.outer(x, y)

    def return_batch(self) -> TensorType["batch"]:
        return rand(4)

    def return_other(self) -> TensorType["other"]:
        return rand(3)

test()

Right now, IIUC, only the dimension sizes are checked (each name just has to bind to a consistent size within the call), so in this example there is no error...

I think that I could use is_named in TensorType, but it gets very cumbersome because we also need to pass names=... every time we create a tensor. This could be OK... but it can get even worse, because some PyTorch operations don't seem to work with named tensors (outer here, at least with 1.9.1), so we need to rename tensors every two lines...

Is it doable to have patch_typeguard(name_check=True), or would it be too complicated to implement? (I think basically I want nominal typing instead of structural typing.)
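To make the nominal vs. structural distinction concrete, here is a minimal sketch in plain Python. It does not use torchtyping or PyTorch: `Named` and `check_names` are hypothetical helpers invented for this illustration, and data is held in plain tuples. The point is only that a name-based check needs every value to carry its dimension names with it.

```python
import inspect
from dataclasses import dataclass
from functools import wraps
from typing import Tuple


@dataclass(frozen=True)
class Named:
    # Hypothetical stand-in for a 1-D tensor: data plus one name per dimension.
    data: Tuple[float, ...]
    names: Tuple[str, ...]


def check_names(**expected: Tuple[str, ...]):
    """Hypothetical decorator: compare each argument's names with the declared ones."""

    def decorator(fn):
        sig = inspect.signature(fn)

        @wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for arg_name, want in expected.items():
                got = bound.arguments[arg_name].names
                if got != want:
                    raise TypeError(f"{arg_name}: expected names {want}, got {got}")
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@check_names(x=("batch",), y=("other",))
def outer(x: Named, y: Named):
    # Nested tuples stand in for the 2-D ("batch", "other") result.
    return tuple(tuple(a * b for b in y.data) for a in x.data)


b = Named((1.0, 2.0), ("batch",))
o = Named((3.0, 4.0, 5.0), ("other",))

result = outer(b, o)  # accepted: the names match the declaration

try:
    outer(o, b)  # rejected by name, although a size-only check would accept it
    swapped_rejected = False
except TypeError:
    swapped_rejected = True
```

In this sketch the swapped call fails because the check compares names, not sizes. The cost is the one described above: every value has to be created with its names attached.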

Thanks for your work!

patrick-kidger commented 2 years ago

I'm afraid named tensors (and TorchTyping's support for them via is_named) are probably as good as this gets right now.

This kind of type-checking for array/tensor types is definitely desirable, but not something that can be easily patched into an existing system like PyTorch -- it'd probably need support in PyTorch itself. (Probably by extending named tensors.)

I think the best support for something like this is currently found in Dex.

tombosc commented 2 years ago

Hey, thanks for the quick answer :) I'm going to look at Dex. And do you know if there is anything that could do that in JAX? (Feel free to close! And thanks again.)

patrick-kidger commented 2 years ago

Hmm, I think there might be a library like this for JAX, but I don't recall which one it is. (There are a few libraries like this floating around, offering various kinds of checking.)