patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

TensorType detail: grad_enabled #21

Status: Open. SimpleConjugate opened this issue 3 years ago.

SimpleConjugate commented 3 years ago

Is it possible to check at runtime whether a tensor has gradients enabled (requires_grad=True)? I'm not sure which cases would need to be tested to confirm this, as I don't fully understand how runtime type checking works. Something like:

import torch
from torchtyping.tensor_details import TensorDetail

class _AutoGradTensorDetail(TensorDetail):
    def check(self, tensor: torch.Tensor) -> bool:
        return tensor.requires_grad  # an attribute, not a method

patrick-kidger commented 3 years ago

Ah, that's a nice idea for a tensor detail.

Yes, that should be completely possible. Quick mock-up (untested):

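from torch import Tensor
from torchtyping.tensor_details import TensorDetail

class _RequiresGradDetail(TensorDetail):
    def check(self, tensor: Tensor) -> bool:
        # Passes only if autograd is tracking this tensor.
        return tensor.requires_grad

    def __repr__(self) -> str:
        # Name used for this detail when describing the expected type.
        return "requires_grad"

    @classmethod
    def tensor_repr(cls, tensor: Tensor) -> str:
        # Describes the tensor actually passed, for error messages.
        if tensor.requires_grad:
            return "requires_grad"
        else:
            return ""

requires_grad = _RequiresGradDetail()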
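To give a feel for how runtime checking of a detail like this can work, here is a stdlib-only sketch. It is not torchtyping's implementation: torchtyping hooks into typeguard (via `patch_typeguard()` and typeguard's `@typechecked`), and there the mock-up above would presumably be used as `TensorType["batch", requires_grad]`. The sketch assumes the detail is attached as `typing.Annotated` metadata and that a decorator calls each detail's `check()` on every call, reusing `__repr__` and `tensor_repr` to word the error. `RequiresGradDetail`, `check_details` and `train_step` are hypothetical names, and `check()` is duck-typed on a `requires_grad` attribute so the sketch runs without PyTorch (Python 3.9+).

```python
import functools
import inspect
from typing import Annotated, get_args, get_origin, get_type_hints


class RequiresGradDetail:
    # Hypothetical stand-in for a TensorDetail subclass; duck-typed so any
    # object with a requires_grad attribute (e.g. a torch.Tensor) works.
    def check(self, tensor) -> bool:
        return bool(getattr(tensor, "requires_grad", False))

    def __repr__(self) -> str:
        return "requires_grad"

    @classmethod
    def tensor_repr(cls, tensor) -> str:
        return "requires_grad" if getattr(tensor, "requires_grad", False) else ""


requires_grad = RequiresGradDetail()


def check_details(fn):
    # Hypothetical decorator: resolve annotations once, then on every call
    # run check() for each piece of Annotated metadata that has one.
    hints = get_type_hints(fn, include_extras=True)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = hints.get(name)
            if get_origin(hint) is not Annotated:
                continue  # plain annotations are ignored in this sketch
            for detail in get_args(hint)[1:]:  # metadata after the base type
                if hasattr(detail, "check") and not detail.check(value):
                    actual = detail.tensor_repr(value) or "no details"
                    raise TypeError(
                        f"argument {name!r}: expected {detail!r}, got {actual}"
                    )
        return fn(*args, **kwargs)

    return wrapper


@check_details
def train_step(weights: Annotated[object, requires_grad]) -> str:
    # Base type is `object` only so the sketch runs without torch installed.
    return "ok"
```

Calling `train_step` with an object whose `requires_grad` is False raises `TypeError` naming the missing detail; a real `torch.Tensor` behaves the same way, since only the `requires_grad` attribute is read. The base type inside `Annotated` is not checked by this sketch; in the real setup typeguard handles that part.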