patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Tensor duck #40

Open · corwinjoy opened this issue 1 year ago

corwinjoy commented 1 year ago

As promised, here is the PR to upgrade the library to define a 'torch-like' protocol and use that as the base type, rather than using torch.Tensor directly. This lets users perform dimension checking on classes that support a Tensor interface but do not directly inherit from torch.Tensor. I think the change is fairly clear-cut; I have added a test case to demonstrate and verify that dimensions are actually checked. The only question I have is about the change to line 304 in typechecker.py (the last change below). Is this test really necessary? I had to change it to use default construction, because protocols don't support isinstance checks if they have properties.

patrick-kidger commented 1 year ago

Thanks for the PR! Unfortunately, this isn't quite the direction I had in mind.

Following on from the discussion in #39, perhaps it's worth making clear that I don't intend to make TorchTyping depend on JAX. Rather, the plan is simply to copy over the non-JAX parts of the code (which is most of it).

The idea would be to end up with annotations that look like Float[Tensor, "batch channel", ...], where ... is information about those PyTorch concepts that don't exist in JAX (in particular, device and layout), and then to add some backward-compatible aliases, so that TensorType[...] is lowered to this new representation.
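For concreteness, such an annotation might be used like this (a sketch only; the function and dimension names are illustrative, the device/layout extension doesn't exist yet, and runtime enforcement would still go through a typechecker such as beartype):

from jaxtyping import Float
from torch import Tensor

# Shape- and dtype-annotated function: input and output must both be
# floating-point tensors of shape (batch, channel).
def scale(x: Float[Tensor, "batch channel"]) -> Float[Tensor, "batch channel"]:
    return 2.0 * x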

At a technical level this should be essentially simple. The main hurdle - and the reason I've been putting off doing this - is writing up documentation that makes this transition clear.

corwinjoy commented 1 year ago

Thanks for the clarification! I can totally see why you want to pull over the jaxtyping code and have a single code base. I understand that this PR is perhaps not what you were looking for, but I think it could actually be a very important step in generalizing what you have, and maybe even in merging the two code bases. Let's take an example snippet from jaxtyping where the dtype is extracted (array_types.py, lines 129 onward):

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        if not isinstance(obj, cls.array_type):
            return False

        if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
            # JAX, numpy
            dtype = obj.dtype.type.__name__
        elif hasattr(obj.dtype, "as_numpy_dtype"):
            # TensorFlow
            dtype = obj.dtype.as_numpy_dtype.__name__
        else:
            # PyTorch
            repr_dtype = repr(obj.dtype).split(".")
            if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
                dtype = repr_dtype[1]
            else:
                raise RuntimeError(
                    "Unrecognised array/tensor type to extract dtype from"
                )

        if cls.dtypes is not _any_dtype and dtype not in cls.dtypes:
            return False

I think you would agree that this is a bit awkward and somewhat hard to extend, since the supported classes have to be coded in advance. Instead, with a design like the one in this PR, we could make _MetaAbstractArray use a protocol class to declare what kinds of properties it checks at runtime. For concreteness, let's say we define this protocol like so:

from typing import Protocol, runtime_checkable

@runtime_checkable
class _ArrayLike(Protocol):
    @property
    def dtype(self) -> AbstractDtype:  # AbstractDtype as in jaxtyping
        ...
    # ... further members (shape, etc.)

Then, to type-check a concrete class like numpy.ndarray or torch.Tensor, we just use the adapter pattern to map the specialized methods onto the interface (for an example of a simple name remapper, see Adapter Method – Python Design Patterns).

This would make it easy for folks like me to extend your library to array-like objects such as LinearOperator, just by writing an adapter to the interface specified by the library (a sketch follows).
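A minimal sketch of what such an adapter could look like (not from the PR; the wrapped class and its element_dtype attribute are hypothetical):

class _LinearOperatorAdapter:
    # Present a LinearOperator-style object through the _ArrayLike
    # interface above, so it can be checked like any other array.
    def __init__(self, op):
        self._op = op

    @property
    def dtype(self):
        # Remap the wrapped object's attribute name to the one
        # the protocol expects.
        return self._op.element_dtype

    @property
    def shape(self):
        return self._op.shape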

In addition, I think it could also let you merge these two libraries and make your life easier. You wrote: "Yep, I did consider merging things. This ends up being infeasible due to edge cases, e.g. JAX and PyTorch (eventually) having different PyTree manipulation routines." Looking at pytree_type.py, it seems plausible that you could define a protocol class along the lines of _TreeLike(Protocol), with methods for accessing leaves (sketched below). Then, JAX PyTree support could be done via an adapter, and this could be an optional import for those who don't want a JAX dependency.

Anyway, I think this could be pretty nice, and I would be happy to help make it happen. I don't thoroughly understand the jaxtyping code, but I think this is doable, and I would be happy to help with documentation!
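A rough sketch of the kind of protocol meant here (the method names are illustrative, not jaxtyping's actual API):

from typing import Any, Iterable, Protocol

class _TreeLike(Protocol):
    # Flatten a tree into its leaves.
    def tree_leaves(self, tree: Any) -> Iterable[Any]: ...

    # Rebuild a tree from a tree definition and a sequence of leaves.
    def tree_unflatten(self, treedef: Any, leaves: Iterable[Any]) -> Any: ...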

patrick-kidger commented 1 year ago

Hmm. I suppose the practical implementation of such an adaptor would be via a registry:

import functools as ft

@ft.singledispatch
def get_dtype(obj):
    # Note that this default implementation does not explicitly
    # depend on any of PyTorch/etc; thus the singledispatch
    # hook is made available just for the sake of user-defined
    # custom types.
    if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
        # JAX, numpy
        dtype = obj.dtype.type.__name__
    elif hasattr(obj.dtype, "as_numpy_dtype"):
        # TensorFlow
        dtype = obj.dtype.as_numpy_dtype.__name__
    else:
        # PyTorch
        repr_dtype = repr(obj.dtype).split(".")
        if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
            dtype = repr_dtype[1]
        else:
            raise RuntimeError(
                "Unrecognised array/tensor type to extract dtype from"
            )
    return dtype

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        ...
        dtype = get_dtype(obj)
        ...

and then in your user code, you could add a custom overload for your type.
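For example (MyOperator and its dtype_name attribute are made up for illustration):

class MyOperator:
    def __init__(self, dtype_name: str):
        self.dtype_name = dtype_name

@get_dtype.register
def _(obj: MyOperator) -> str:
    # Custom overload registered via singledispatch: this type
    # already knows its dtype name, so return it directly.
    return obj.dtype_name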

I'd be willing to accept a PR for this over in jaxtyping.