corwinjoy opened this issue 1 year ago
Thanks for the PR! Unfortunately, this isn't quite the direction I had in mind.
Following on from the discussion in #39, perhaps it's worth making clear that I don't intend to make TorchTyping depend on JAX. Rather, that the plan is to simply copy over the non-JAX parts of the code. (Which is most of it.)
The idea would be to end up with annotations that look like `Float[Tensor, "batch channel", ...]`, where `...` is information about those PyTorch concepts that don't exist in JAX (in particular, device and layout). And then add in some backward-compatible aliases, so that `TensorType[...]` is lowered to this new representation.
At a technical level this should be essentially simple. The main hurdle, and the reason I've been putting off doing this, is writing up documentation that makes this transition clear.
Thanks for the clarification! I can totally see why you want to pull over the jaxtyping code and have a single code base. I understand that this PR is perhaps not what you were looking for, but I think it could actually represent a very important step in generalizing what you have and maybe even merging the two code bases. Let's take an example snippet from jaxtyping where the dtype is extracted (`array_types.py`, lines 129 onward):
```python
class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        if not isinstance(obj, cls.array_type):
            return False
        if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
            # JAX, numpy
            dtype = obj.dtype.type.__name__
        elif hasattr(obj.dtype, "as_numpy_dtype"):
            # TensorFlow
            dtype = obj.dtype.as_numpy_dtype.__name__
        else:
            # PyTorch
            repr_dtype = repr(obj.dtype).split(".")
            if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
                dtype = repr_dtype[1]
            else:
                raise RuntimeError(
                    "Unrecognised array/tensor type to extract dtype from"
                )
        if cls.dtypes is not _any_dtype and dtype not in cls.dtypes:
            return False
```
I think you would agree that this is a bit awkward and somewhat hard to extend, since the supported classes have to be coded in advance. Instead, with a design like the one in this PR, we could make `_MetaAbstractArray` use a protocol class to declare what kinds of properties it is checking at runtime. For concreteness, let's say we define this protocol like:
```python
from typing import Protocol

class _ArrayLike(Protocol):
    @property
    def dtype(self) -> AbstractDtype:
        ...

    # ... further properties as needed
```
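For a sense of how such a protocol behaves at runtime, here is a minimal, hypothetical sketch (names like `ArrayLike` and `MyArray` are made up for illustration). With `@runtime_checkable`, `isinstance` only checks that the named attributes exist, not their types:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class ArrayLike(Protocol):
    @property
    def dtype(self) -> str:
        ...

class MyArray:
    # Not a subclass of ArrayLike; it merely has a matching attribute.
    @property
    def dtype(self) -> str:
        return "float32"

print(isinstance(MyArray(), ArrayLike))  # True: has a `dtype` attribute
print(isinstance(object(), ArrayLike))   # False: no `dtype` attribute
```

This structural check is what would let arbitrary array-like classes participate without inheriting from any particular base class.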
Then, to type-check a concrete class like `numpy.ndarray` or `torch.Tensor`, we just use the adapter pattern to map the specialized methods to the interface. (As an example, a simple name remapper: Adapter Method – Python Design Patterns.)
This would make it easy for folks like me to extend your library to array-type objects such as LinearOperator by just writing an adapter to the interface specified by the library.
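The adapter idea can be sketched in a few lines. Everything here is hypothetical: `FakeLinearOperator` stands in for an array-like class (such as LinearOperator) that does not expose `dtype`/`shape` under those names, and the adapter remaps its methods onto the protocol-style interface:

```python
class FakeLinearOperator:
    """A stand-in for an array-like class with its own naming scheme."""

    def __init__(self, rows, cols):
        self._rows = rows
        self._cols = cols

    def matrix_shape(self):
        return (self._rows, self._cols)

    def scalar_kind(self):
        return "float32"

class LinearOperatorAdapter:
    """Adapts FakeLinearOperator to a dtype/shape interface."""

    def __init__(self, op):
        self._op = op

    @property
    def dtype(self):
        return self._op.scalar_kind()

    @property
    def shape(self):
        return self._op.matrix_shape()

adapted = LinearOperatorAdapter(FakeLinearOperator(3, 4))
print(adapted.dtype)  # float32
print(adapted.shape)  # (3, 4)
```

The point is that only the adapter needs to know about the wrapped class; the type-checking machinery sees a uniform interface.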
In addition, I think it could also let you merge these two libraries and make your life easier. You wrote: "Yep, I did consider merging things. This ends up being infeasible due to edge cases, e.g. JAX and PyTorch (eventually) having different PyTree manipulation routines." Looking at `pytree_type.py`, it seems plausible that you could define a protocol class along the lines of `_TreeLike(Protocol)` with methods for accessing leaves. Then JAX pytree support could be done via an adapter, and this could be an optional import for those who don't want a JAX dependency.

Anyway, I think this could be pretty nice and I would be happy to help make it happen. I don't thoroughly understand the jaxtyping code, but I think this is doable, and I would also be happy to help with documentation!
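To make the `_TreeLike` suggestion concrete, a minimal sketch might look like the following. All names here are hypothetical (`TreeLike`, `NestedList`, `leaves`); a JAX adapter could implement the same interface on top of `jax.tree_util` in an optional module:

```python
from typing import Any, Iterable, Protocol, runtime_checkable

@runtime_checkable
class TreeLike(Protocol):
    def leaves(self) -> Iterable[Any]:
        ...

class NestedList:
    """A toy pytree: arbitrarily nested lists of leaves."""

    def __init__(self, data):
        self._data = data

    def leaves(self):
        def walk(node):
            if isinstance(node, list):
                for child in node:
                    yield from walk(child)
            else:
                yield node
        return walk(self._data)

tree = NestedList([1, [2, [3, 4]], 5])
print(list(tree.leaves()))  # [1, 2, 3, 4, 5]
print(isinstance(tree, TreeLike))  # True
```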
Hmm. I suppose the practical implementation of such an adapter would be via a registry:
```python
import functools as ft

@ft.singledispatch
def get_dtype(obj):
    # Note that this default implementation does not explicitly
    # depend on any of PyTorch/etc.; the singledispatch hook is
    # made available just for the sake of user-defined custom types.
    if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
        # JAX, numpy
        return obj.dtype.type.__name__
    elif hasattr(obj.dtype, "as_numpy_dtype"):
        # TensorFlow
        return obj.dtype.as_numpy_dtype.__name__
    else:
        # PyTorch
        repr_dtype = repr(obj.dtype).split(".")
        if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
            return repr_dtype[1]
        raise RuntimeError(
            "Unrecognised array/tensor type to extract dtype from"
        )
```
```python
class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        ...
        dtype = get_dtype(obj)
        ...
```
and then in your user code, you could add a custom overload for your type.
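A self-contained sketch of what that user-side registration might look like (the default implementation is stubbed out here so the snippet runs on its own, and `MyDuckArray` is a made-up user-defined type):

```python
import functools as ft

@ft.singledispatch
def get_dtype(obj):
    # Stub default; the real version would inspect obj.dtype as above.
    raise RuntimeError("Unrecognised array/tensor type to extract dtype from")

class MyDuckArray:
    """A user-defined array-like type unknown to the library."""

    def __init__(self, dtype):
        self._dtype = dtype

# Register a custom overload: dispatch is keyed off the annotated type.
@get_dtype.register
def _(obj: MyDuckArray):
    return obj._dtype

print(get_dtype(MyDuckArray("float32")))  # float32
```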
I'd be willing to accept a PR for this over in jaxtyping.
As promised, here is the PR to upgrade the library to define a 'torch-like' protocol and use that as the base type, rather than using `torch.Tensor` directly. This lets users perform dimension checking on classes that support a Tensor interface but do not directly inherit from `torch.Tensor`. I think the change is fairly clear-cut; I have added a test case to demonstrate and verify that dimensions are actually checked. The only question I have is about the change to line 304 in `typechecker.py` (the last change below). Is this test really necessary? I had to change it to use default construction because protocols don't support `isinstance` if they have properties.