pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

torchvision transforms for video and sequence data #7476

Closed: JohannesTheo closed this issue 1 year ago

JohannesTheo commented 1 year ago

📚 The doc issue

Hello everyone,

I'd like to discuss how torchvision.transforms are applied to video and, more generally, to sequence data. My examples refer to torchvision.transforms.Resize, which I have tested, but the issue might apply to other transforms as well.

Before I present the case, let me be clear that I am not looking for an immediate solution to my problem! I'm well aware of alternatives such as PyTorchVideo, custom transforms, etc. The idea of this post is to discuss the topic with the community first, before I work on a pull request, i.e. to make sure I'm not completely wrong about this (which I well might be :D).

So, let's start with the documentation of torchvision.transforms.Resize, which states:

Resize the input image to the given size. If the image is torch Tensor, it is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions

I think this design is quite elegant for sequences, but I found that it does not work as expected in some cases. Consider the following data:

import torch

video = torch.rand(size=[25, 3, 64, 64])
masks = torch.rand(size=[25, 10, 1, 64, 64])

In this case, video is a normal sequence of 64x64 RGB images in TCHW format. Similarly, masks is a sequence of corresponding object masks for 10 objects, each represented as a 64x64 grayscale image, i.e. TMCHW format. Now, if we want to resize the video, we have to resize the mask sequence as well, and according to the docs, we can do something like this:

from torchvision import transforms

resize = transforms.Resize(size=32) # or size=(32,32) 

print(resize(video).shape)
# > torch.Size([25, 3, 32, 32])

print(resize(masks).shape)
# > ValueError: Input and output must have the same number of spatial dimensions, but got input with spatial dimensions of [1, 64, 64] and output size of [32, 32]. Please provide input tensor in (N, C, d1, d2, ...,dK) format and output size in (o1, o2, ...,oK) format.

As you can see, this works for the normal video sequence but not for the masks sequence, which has an extra dimension. This was a little surprising, because the docs promised it would work with "an arbitrary number of leading dimensions".

In response, I did some digging, and the problem arises from the following call stack:

  1. torchvision.transforms.Resize calls
  2. torchvision.transforms.functional.resize which first calculates
  3. torchvision.transforms.functional._compute_resized_output_size (this always returns [new_h, new_w], so from here on we are guaranteed to have a 2D size = list[int,int]), and then calls
  4. torchvision.transforms._functional_tensor.resize which finally calls
  5. torch.nn.functional.interpolate with size=[new_h, new_w] from step 3.
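Step 3 is the crux: whatever the rank of the input, the size handed to interpolate always has exactly two entries. The pure-Python sketch below mimics the size rule that the Resize docs describe (an int is matched to the shorter edge, a pair is used as-is). The name compute_resized_output_size is illustrative only, not torchvision's private helper, and max_size handling is deliberately left out.

```python
from typing import List, Sequence, Union


def compute_resized_output_size(
    image_size: Sequence[int], size: Union[int, Sequence[int]]
) -> List[int]:
    """Hypothetical sketch of Resize's documented size rule (max_size ignored).

    image_size is (H, W) of the last two dims. Regardless of how many leading
    dims the input has, the result always has two entries: [new_h, new_w].
    """
    h, w = image_size
    if isinstance(size, int):
        size = [size]
    if len(size) == 2:
        # An explicit (h, w) pair is used as-is.
        return [int(size[0]), int(size[1])]
    if len(size) != 1:
        raise ValueError("If size is a sequence, it should have 1 or 2 values")
    # A single value is matched to the shorter edge; the longer edge keeps the aspect ratio.
    short, long = (w, h) if w <= h else (h, w)
    new_short = size[0]
    new_long = int(new_short * long / short)
    new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
    return [new_h, new_w]


print(compute_resized_output_size((64, 64), 32))        # [32, 32]
print(compute_resized_output_size((64, 128), 32))       # [32, 64]
print(compute_resized_output_size((64, 64), (32, 32)))  # [32, 32]
```

Because the result is always 2D, any input with more than one leading dimension besides the channel dimension ends up with a rank mismatch in interpolate.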

Now, the problem I'd like to discuss (and the reason the docs are a little misleading) is that torch.nn.functional.interpolate assumes the input to be:

mini-batch x channels x [optional depth] x [optional height] x width

Because of this assumption (two leading mandatory dimensions) it calculates:

dim = input.dim() - 2  # Number of spatial dimensions.

in line 3865 of torch/nn/functional.py and then, in line 3877, checks:

if len(size) != dim:
    raise ValueError(...)

which raises the error mentioned above. To fully understand what's going on, let's recreate the cases from above:

import torch
from torchvision import transforms

size  = (32, 32)
video = torch.rand(size=[25, 3, 64, 64])
masks = torch.rand(size=[25, 10, 1, 64, 64])

torch.nn.functional.interpolate(video, size=size)
# this works because video.dim() - 2 == len(size), i.e. 2 == 2

torch.nn.functional.interpolate(masks, size=size)
# this raises a ValueError because masks.dim() - 2 != len(size), i.e. 3 != 2

# a working solution could be:
print(torch.nn.functional.interpolate(masks, size=(1, 32, 32)).shape)
# > torch.Size([25, 10, 1, 32, 32])

# however:
resize = transforms.Resize(size=(1, 32, 32))
# > ValueError: If size is a sequence, it should have 1 or 2 values

To summarize, it works for 4D video sequences (by accident?) because interpolate interprets the sequence (time) dimension as the batch dimension. For higher-dimensional sequences, like my masks example, it breaks.

It could work, though; it is currently just limited by the size check in torchvision.transforms.Resize and/or by the dimension assumptions made in torch.nn.functional.interpolate.
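One way to make the v1-style call work for any number of leading dimensions is to collapse them into the batch dimension before calling interpolate and restore them afterwards. The helper resize_leading_dims below is hypothetical (not a torchvision API), a minimal sketch assuming a floating-point tensor of shape [..., H, W]; it resizes every H x W plane independently, which is what interpolate does per channel anyway.

```python
import torch
import torch.nn.functional as F


def resize_leading_dims(x: torch.Tensor, size: tuple, mode: str = "bilinear") -> torch.Tensor:
    """Hypothetical helper: resize the last two dims of x, treating all leading dims as batch."""
    *leading, h, w = x.shape
    # Collapse every leading dim into one batch dim and add a singleton channel,
    # so interpolate sees (N, 1, H, W), i.e. exactly two spatial dims.
    flat = x.reshape(-1, 1, h, w)
    # align_corners may only be passed for the interpolating modes, not for "nearest".
    kwargs = {"align_corners": False} if mode in ("bilinear", "bicubic") else {}
    out = F.interpolate(flat, size=size, mode=mode, **kwargs)
    # Restore the original leading dims with the new spatial size.
    return out.reshape(*leading, *size)


video = torch.rand(size=[25, 3, 64, 64])
masks = torch.rand(size=[25, 10, 1, 64, 64])
print(resize_leading_dims(video, (32, 32)).shape)                  # torch.Size([25, 3, 32, 32])
print(resize_leading_dims(masks, (32, 32), mode="nearest").shape)  # torch.Size([25, 10, 1, 32, 32])
```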

Suggest a potential alternative/fix

What to do?

  1. Update the documentation? This is an easy fix, but I actually want this to work as described, which is the whole point of this lengthy post :D
  2. Update the dimension check in torch.nn.functional.interpolate?
  3. Update how the torchvision.transforms stack handles video data?

What do you think?

From what I can see, _functional_video.py is deprecated, and v2/functional/_geometry.py has a function resize_video, which just calls resize_image_tensor, which in turn calls torch.nn.functional.interpolate with a 2D size and will therefore suffer from the same problem. I'm also aware of https://pytorchvideo.org, but it requires CTHW format and adds OpenCV as a dependency.

Personally, I think the behavior described in the current documentation is the most elegant and least limiting. Also, a pull request that makes it work as promised should not be too hard. On the other hand, I only checked this for Resize, and other transforms should probably be checked as well...

If you made it this far, thanks for reading :) I'd really appreciate your input on this!

pmeier commented 1 year ago

Thanks @JohannesTheo for the detailed write-up! Indeed, the documentation is wrong, since we only support three- and four-dimensional inputs:

https://github.com/pytorch/vision/blob/0387b8821d67ca62d57e3b228ade45371c0af79d/torchvision/transforms/_functional_tensor.py#L516

We just released torchvision.transforms.v2, which no longer has this issue. Just replacing from torchvision import transforms above with import torchvision.transforms.v2 as transforms and re-running prints:

torch.Size([25, 3, 32, 32])
torch.Size([25, 10, 1, 32, 32])

Additionally, one thing you haven't considered so far is that Resize uses bilinear interpolation by default

https://github.com/pytorch/vision/blob/0387b8821d67ca62d57e3b228ade45371c0af79d/torchvision/transforms/transforms.py#L337

and that is not well suited for masks. You need nearest interpolation for those. The v2 API handles this for you automatically if you wrap your tensors in what we call datapoints:

from torchvision import datapoints

video = datapoints.Video(video)
masks = datapoints.Mask(masks)

This will automatically select the right kernel for each type, which means bilinear interpolation for videos

https://github.com/pytorch/vision/blob/0387b8821d67ca62d57e3b228ade45371c0af79d/torchvision/transforms/v2/functional/_geometry.py#L247-L250

and nearest for masks

https://github.com/pytorch/vision/blob/0387b8821d67ca62d57e3b228ade45371c0af79d/torchvision/transforms/v2/functional/_geometry.py#L225
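To see why bilinear interpolation is a poor fit for masks, here is a small self-contained check that calls torch.nn.functional.interpolate directly on a toy 4x4 binary mask (the mask and sizes are chosen purely for illustration). Bilinear upsampling blends neighboring pixels and produces fractional values along the object boundary, while nearest interpolation only copies existing values, so the mask stays binary.

```python
import torch
import torch.nn.functional as F

# Toy binary mask in (N, C, H, W) layout containing a 2x2 square of ones
mask = torch.zeros(1, 1, 4, 4)
mask[..., 1:3, 1:3] = 1.0

up_bilinear = F.interpolate(mask, size=(8, 8), mode="bilinear", align_corners=False)
up_nearest = F.interpolate(mask, size=(8, 8), mode="nearest")

# Bilinear introduces values strictly between 0 and 1 at the square's edges
print(((up_bilinear > 0) & (up_bilinear < 1)).any())  # tensor(True)
# Nearest only reuses the original labels
print(up_nearest.unique())  # tensor([0., 1.])
```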

As an added bonus, you can also transform both inputs at the same time:

video2, masks2 = resize(video, masks)
print(video2.shape, masks2.shape)
# > torch.Size([25, 3, 32, 32]) torch.Size([25, 10, 1, 32, 32])

That was a lot of text to say: we are focusing on the v2 API going forward. Since it already has the behavior you are asking for, I think it is best to just fix the docstring of the v1 API. If you want to send a PR, go for it!

JohannesTheo commented 1 year ago

Hey @pmeier, thanks for the quick response and the detailed explanation! I haven't checked the v2 API in detail, but this looks really great, in particular the mask interpolation part, which I indeed missed 😅 I'll open a PR for the v1 API docstring.