Thanks @JohannesTheo for the detailed write-up! Indeed, the documentation is wrong, since we only support three and four dimensions.

We just released `torchvision.transforms.v2`, which no longer has this issue. Just replacing `from torchvision import transforms` above with `import torchvision.transforms.v2 as transforms` and re-running prints:
```
torch.Size([25, 3, 32, 32])
torch.Size([25, 10, 1, 32, 32])
```
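For context, here is a minimal, self-contained sketch of that run; the tensor shapes are inferred from the outputs above, and `Resize(size=(32, 32))` is an assumption:

```python
import torch
import torchvision.transforms.v2 as transforms

video = torch.rand(25, 3, 64, 64)      # TCHW
masks = torch.rand(25, 10, 1, 64, 64)  # TMCHW

resize = transforms.Resize(size=(32, 32))  # target size assumed from the shapes above
print(resize(video).shape)  # torch.Size([25, 3, 32, 32])
print(resize(masks).shape)  # torch.Size([25, 10, 1, 32, 32])
```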
Additionally, one thing that you didn't consider so far is that `Resize` by default uses bilinear interpolation, which is not well suited for masks; you need nearest interpolation for those. The v2 API handles this automatically for you if you wrap your tensors in what we call datapoints:
```python
from torchvision import datapoints

# Wrapping the tensors tells the v2 transforms their semantic type
video = datapoints.Video(video)
masks = datapoints.Mask(masks)
```
This will automatically select the right kernel for each type, i.e. bilinear interpolation for videos and nearest for masks.
As an added bonus, you can also transform both inputs at the same time:
```python
# using the same transform as before, e.g. resize = transforms.Resize(size=(32, 32))
video2, masks2 = resize(video, masks)
print(video2.shape, masks2.shape)
```

```
torch.Size([25, 3, 32, 32]) torch.Size([25, 10, 1, 32, 32])
```
That was a lot of text to say: we are focusing on the v2 API going forward. Since it already has the behavior you are asking for, I think it is best to just fix the docstring of the v1 API. If you want to send a PR, go for it!
Hey @pmeier, thanks for the quick response and the detailed explanation! I haven't checked the v2 API in detail yet, but this looks really great, in particular the mask interpolation part, which I had indeed missed 😅 I'll go for a PR for the v1 API docstring.
📚 The doc issue
Hello everyone,
I'd like to discuss how `torchvision.transforms` are applied to video and, more generally, to sequence data. My examples will refer to `torchvision.transforms.Resize`, which I have tested, but the issue might apply to other transforms as well.

Before I present the case, let me be clear that I am not looking for an immediate solution to my problem! I'm well aware of alternatives such as PyTorchVideo, custom transforms, etc. The idea of this post is to discuss the topic with the community first, before I work on a pull request, i.e. to ensure I'm not completely wrong on this (which I well might be :D).
So, let's start with the documentation of `torchvision.transforms.Resize`, which states:
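> If the image is torch Tensor, it is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions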
I think this solution is quite elegant for sequences, but I found that it does not work as expected in some cases. Consider the following data:
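A reconstructed sketch, with shapes inferred from the outputs discussed later in this issue:

```python
import torch

video = torch.rand(25, 3, 64, 64)      # TCHW: 25 frames of 64x64 RGB images
masks = torch.rand(25, 10, 1, 64, 64)  # TMCHW: per-frame masks for 10 objects
```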
In this case, `video` is a normal sequence of 64x64 RGB images in TCHW format. Similarly, `masks` is a sequence of corresponding object masks for 10 objects, each represented as a 64x64 grayscale image, so TMCHW. Now, if we want to resize the video, we have to resize the masks sequence as well, and according to the docs we can do something like this:
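A reconstruction of that snippet (the target size is an assumption, matching the shapes above):

```python
from torchvision import transforms

resize = transforms.Resize(size=(32, 32))

video_resized = resize(video)  # works: torch.Size([25, 3, 32, 32])
masks_resized = resize(masks)  # raises a ValueError from interpolate
```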
As you can see, this works for the normal video sequence but not for the special masks sequence, which has an extra dimension. This was a little surprising, because the docs promised it would work with "an arbitrary number of leading dimensions".
In response, I did some digging, and the problem arises from the following call stack:
1. `torchvision.transforms.Resize` calls
2. `torchvision.transforms.functional.resize`, which first calculates
3. `torchvision.transforms.functional._compute_resized_output_size` (this always returns `[new_h, new_w]`, so from here on we are guaranteed to have a 2D `size = list[int, int]`), and then calls
4. `torchvision.transforms._functional_tensor.resize`,
5. which finally calls `torch.nn.functional.interpolate` with the `size=[new_h, new_w]` from step 3.

Now, the problem I'd like to discuss (and the reason the docs are a little misleading) is that `torch.nn.functional.interpolate` assumes the input to be:

> mini-batch x channels x [optional depth] x [optional height] x width

Because of this assumption (two leading mandatory dimensions), it calculates:
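Roughly the following (paraphrased from `torch/nn/functional.py`; the exact code varies between PyTorch versions):

```python
dim = input.dim() - 2  # number of spatial dimensions
```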
in line 3865 and then checks in line 3877:
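Again paraphrased, the check looks roughly like this:

```python
if len(size) != dim:
    raise ValueError(
        "Input and output must have the same number of spatial dimensions, ..."
    )
```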
which will raise the error mentioned above. To fully understand what's going on, let's recreate the cases from above:
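A reconstruction of that experiment, calling `interpolate` directly (shapes as in the data sketch above):

```python
import torch
import torch.nn.functional as F

video = torch.rand(25, 3, 64, 64)      # 4D: interpreted as NCHW  -> dim = 2
masks = torch.rand(25, 10, 1, 64, 64)  # 5D: interpreted as NCDHW -> dim = 3

print(F.interpolate(video, size=[32, 32]).shape)  # torch.Size([25, 3, 32, 32])

# len(size) == 2 but dim == 3 -> ValueError
F.interpolate(masks, size=[32, 32])
```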
To summarize, it works for 4D video sequences (by accident?) because `interpolate` will interpret the sequence dimension as the batch dimension. For higher-dimensional sequences, like in my masks example, it breaks. It could work, though; it is currently just limited by the size check in `torchvision.transforms.Resize` and/or by the dimension assumptions made in `torch.nn.functional.interpolate`.
Suggest a potential alternative/fix
What to do? Should this be fixed in `torch.nn.functional.interpolate`, or in how the `torchvision.transforms` stack handles video data? What do you think?
From what I can see, `_functional_video.py` is deprecated, and `v2/functional/_geometry.py` has a function `resize_video` which just calls `resize_image_tensor`, which again calls `torch.nn.functional.interpolate` with a 2D size and will therefore suffer from the same problem. I'm also aware of https://pytorchvideo.org, but they require CTHW format and add OpenCV as a dependency.

Personally, I think the solution as described in the current documentation is the most elegant and least limiting. Also, a pull request to make it work as promised should not be too hard. On the other hand, I have only checked this for `Resize`, and it should probably be checked for other transforms as well...

If you made it this far, thanks for reading :) I'd really appreciate your input on this!