albanD closed this issue 3 months ago.
Unfortunately, torchtyping has long been deprecated in favour of jaxtyping, which, despite the historical name, does not depend on JAX and does support PyTorch. It also does far fewer evil things under the hood.
So I'd really recommend that the issue reporter update to using jaxtyping instead!
(FWIW, I can see that torchtyping is being pulled in by a dependency, which might be tricky to update. If that's easier, I'd be happy to take a PR here.
I think the mixin is needed to make some IDEs work in some unusual way, but that's not important.)
Hello all, which version of PyTorch should I roll back to in order to solve this issue?
torch < 2.4 should work in terms of releases. For nightlies, I'd need to check a bit more carefully.
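As a rough way to apply the "torch < 2.4" advice, a small check can classify a torch version string. This is a hypothetical helper, not part of PyTorch: it compares only the leading major.minor numbers (so local tags such as `+cu121` are ignored) and deliberately returns None for nightly (`.dev`) builds, since those would need to be checked individually.

```python
import re
from typing import Optional


def torch_version_is_safe(version: str) -> Optional[bool]:
    """Hypothetical helper: True if a torch release predates 2.4.

    Returns None for nightly (".dev") builds, whose behaviour cannot be
    inferred from the version number alone.
    """
    if ".dev" in version:
        return None  # nightly build: check it manually
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) < (2, 4)


print(torch_version_is_safe("2.3.1"))        # True
print(torch_version_is_safe("2.4.0+cu121"))  # False
```

The installed version string can be read with `importlib.metadata.version("torch")`; for fresh installs, pinning `torch<2.4` in your requirements has the same effect.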
Sorry for the churn here. If you think this is going to lead to a lot of churn and cannot be worked around in dependencies, I'm willing to consider reverting that change in PyTorch. I wasn't expecting anyone to rely on this as a side effect :/
I'm happy either way! I can press the button on a new release of torchtyping if need be.
Thanks for the quick turnaround @patrick-kidger !
When using the latest version of PyTorch, users are seeing the following error: "RuntimeError: Cannot subclass _TensorBase directly"
https://github.com/pytorch/pytorch/issues/131463
This happens because PyTorch recently started forbidding subclassing the raw C++ `_TensorBase` type, in favor of subclassing `torch.Tensor` directly. Subclassing the base type is very unsafe: many of its methods assume that attributes defined on `Tensor` are present, so it is never really a valid class to subclass on its own.
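The restriction can be illustrated without PyTorch using a small pure-Python model. `FakeMeta`, `FakeTensorBase`, and `FakeTensor` are hypothetical stand-ins for the shared metaclass, the raw base type, and `torch.Tensor`; the check only mimics the behaviour described here and is not PyTorch's actual implementation.

```python
from typing import Optional

_tensor_cls: Optional[type] = None  # set once the Tensor stand-in exists


class FakeMeta(type):
    """Stand-in metaclass: rejects classes that do not inherit from FakeTensor."""

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        if _tensor_cls is None:
            return  # still bootstrapping FakeTensorBase and FakeTensor
        if not issubclass(cls, _tensor_cls):
            raise RuntimeError("Cannot subclass _TensorBase directly")


class FakeTensorBase(metaclass=FakeMeta):
    """Stand-in for the raw C++ base type."""


class FakeTensor(FakeTensorBase):
    """Stand-in for torch.Tensor."""


_tensor_cls = FakeTensor  # from now on, every new class must derive from FakeTensor


def subclass_allowed(base: type) -> bool:
    """Try to create an empty subclass of `base` and report whether it is accepted."""
    try:
        FakeMeta("Probe", (base,), {})
    except RuntimeError:
        return False
    return True


print(subclass_allowed(FakeTensor))      # True: the supported route
print(subclass_allowed(FakeTensorBase))  # False: direct subclass of the base
```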
Unfortunately, the trick used in https://github.com/patrick-kidger/torchtyping/blob/1f3749c5b5617ec6b6449e98ad9ae3fb6645ef54/torchtyping/tensor_type.py#L32 to create the mixin based on `type(torch.Tensor)` means the metatype sees a new class of that metatype being created without `torch.Tensor` as a base class, which triggers the error.
I'm not sure why we need the extra step of creating the mixin and then having the final class inherit from `Tensor` here: https://github.com/patrick-kidger/torchtyping/blob/1f3749c5b5617ec6b6449e98ad9ae3fb6645ef54/torchtyping/tensor_type.py#L176. Dropping the mixin and declaring the class as `class TensorType(torch.Tensor, metaclass=_TensorTypeMeta):` would avoid this issue. I'm not sure that is a proper fix, though.