tombosc closed this issue 2 years ago.
I'm afraid named tensors (and TorchTyping's support for them via `is_named`) are probably as good as this gets right now.
This kind of type-checking for array/tensor types is definitely desirable, but it isn't something that can easily be patched into an existing system like PyTorch; it would probably need support in PyTorch itself (most likely by extending named tensors).
I think the best support for something like this is currently found in Dex.
Hey, thanks for the quick answer :) I'm going to look at Dex. And do you know if there is anything that could do this in JAX? (Feel free to close, and thanks again!)
Hmm, I think there might be a library like this for JAX, but I don't recall which one it is. (There are a few libraries like this floating around, offering various kinds of checking.)
Hello
I'd like to know if there's an easy way to check tensors by name:
Right now, IIUC, only dimension sizes are checked, so in this example there is no error...
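As a hypothetical illustration of the difference between a size-only (structural) check and a name-based (nominal) check, here is a plain-Python sketch. The `FakeTensor` class and both helper functions are invented for this example and are not TorchTyping or PyTorch code; the size check is also a simplification of what TorchTyping actually does. A tensor whose two dimensions are swapped but happen to have equal sizes passes the size check and fails the name check:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only its shape and per-dimension names."""
    shape: Tuple[int, ...]
    names: Tuple[Optional[str], ...]


def sizes_match(t: FakeTensor, expected_shape: Tuple[int, ...]) -> bool:
    # Structural check: only the number and sizes of dimensions matter.
    return t.shape == expected_shape


def names_match(t: FakeTensor, expected_names: Tuple[str, ...]) -> bool:
    # Nominal check: each dimension must carry the expected name, in order.
    return t.names == expected_names


# Intended layout is ("batch", "feature"), but this tensor is ("feature", "batch").
# Both dimensions have size 4, so sizes alone cannot tell them apart.
x = FakeTensor(shape=(4, 4), names=("feature", "batch"))

print(sizes_match(x, (4, 4)))                # True: the swap goes unnoticed
print(names_match(x, ("batch", "feature")))  # False: the swap is caught
```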
I think that I could use `is_named` in `TensorType`, but it gets very cumbersome because we also need to use `names=...` every time we declare a tensor. This could be OK... but it can get even worse, because some PyTorch operations don't seem to work with named tensors (`outer` here! at least with 1.9.1), so we need to rename tensors every two lines...
Is it doable to have `patch_typeguard(name_check=True)`, or would it be too complicated to implement? (I think basically I want nominal typing instead of structural typing.)
Thanks for your work!
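`patch_typeguard(name_check=True)` is the option being asked for here, not an existing TorchTyping flag. As a rough sketch of what such a nominal check might do, the hypothetical decorator below (`check_dim_names` and the `Named` stand-in class are invented for illustration) compares the `.names` tuple of selected arguments, and optionally of the return value, against expected dimension names. It assumes the checked objects expose `.names` the way PyTorch named tensors do; it checks names only, not sizes, and does nothing about operations such as `outer` that reject named inputs.

```python
import functools
import inspect
from typing import Any, Callable, Optional, Sequence


def check_dim_names(returns: Optional[Sequence[str]] = None, **expected: Sequence[str]) -> Callable:
    """Hypothetical decorator: compare `.names` of the given arguments (and
    optionally of the return value) against the expected dimension names."""

    def decorator(fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        def _check(label: str, value: Any, want: Sequence[str]) -> None:
            # Objects without a `.names` attribute are treated as having no names.
            got = tuple(getattr(value, "names", ()))
            if got != tuple(want):
                raise TypeError(
                    f"{fn.__name__}: {label} has dims {got}, expected {tuple(want)}"
                )

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            bound = sig.bind(*args, **kwargs)
            for param, want in expected.items():
                # Only arguments that were actually passed are checked.
                if param in bound.arguments:
                    _check(f"argument '{param}'", bound.arguments[param], want)
            result = fn(*args, **kwargs)
            if returns is not None:
                _check("return value", result, returns)
            return result

        return wrapper

    return decorator


class Named:
    """Hypothetical stand-in exposing a `.names` tuple, like a PyTorch named tensor."""

    def __init__(self, *names: Optional[str]) -> None:
        self.names = names


@check_dim_names(x=("batch", "feature"), returns=("batch",))
def pool(x: Named) -> Named:
    # Pretend to reduce over the "feature" dimension.
    return Named("batch")


pool(Named("batch", "feature"))  # passes: names match
try:
    pool(Named("feature", "batch"))  # same rank, wrong order of names
except TypeError as err:
    print(err)
```

Because it compares names rather than sizes, this is closer to the nominal typing described above; a real feature would presumably hook into TorchTyping's existing typeguard integration rather than use a separate decorator.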