pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Transforms with nested tensor #7761

Open agunapal opened 1 year ago

agunapal commented 1 year ago

🚀 The feature

For batched inference on images of different sizes, we need to do the following

It would be nice to do the following instead

Motivation, pitch

This would result in improved performance for image pre-processing

Alternatives

Additional context

No response

AnimeshMaheshwari22 commented 1 year ago

Hello @agunapal, this is quite interesting! This would be part of transforms, right?

agunapal commented 1 year ago

If this is possible, yes. This would need to be supported by transforms to address the pre-processing bottleneck in inference.

NicolasHug commented 1 year ago

Hi @agunapal , thanks for the feature request.

I understand that in general, processing input in batches makes the transforms faster. And I also acknowledge that passing batches of images to Resize is pretty much impossible as-is, because.. well, we can't batch images of different sizes.
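
To make the limitation concrete, here is a minimal sketch (the image shapes and the 24x24 target size are made up for illustration): stacking differently sized images into one batch fails, so each image has to be resized on its own before batching.

```python
import torch
import torch.nn.functional as F

# Two hypothetical images in (C, H, W) layout with different spatial sizes.
images = [torch.rand(3, 32, 48), torch.rand(3, 64, 40)]

# Stacking into a single (N, C, H, W) batch fails because the shapes differ.
try:
    torch.stack(images)
    stack_failed = False
except RuntimeError:
    stack_failed = True

# Resizing each image individually works; interpolate() expects a 4D input,
# so a temporary batch dimension of size 1 is added per image.
resized = [
    F.interpolate(img.unsqueeze(0), size=(24, 24), mode="bilinear", align_corners=False)
    for img in images
]

# Once all images share a size, they can be concatenated into a regular batch.
batch = torch.cat(resized)
```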

So that's where NestedTensor comes in, as it provides a nice UX to manipulate tensors of different sizes. But unfortunately, I'm afraid NestedTensor won't help regarding perf. There aren't a lot of torch operations that natively support NestedTensor, and in particular torch.nn.functional.interpolate, which is what Resize() relies on, isn't one of them. So technically, even if there were support for NestedTensor in torchvision, we wouldn't be able to do much more than manually loop over the entries of the NestedTensor and pass them one by one to interpolate().
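
A rough sketch of what that manual loop could look like, assuming a hypothetical helper named `halve_each` (it is not part of torchvision) and made-up image shapes. The NestedTensor only serves as a container here; the actual resizing still happens one image at a time.

```python
import torch
import torch.nn.functional as F


def halve_each(nt):
    """Hypothetical helper: downscale every image in a NestedTensor by 2x.

    There is no batched kernel involved: the entries are unpacked and
    passed to interpolate() one by one, then re-wrapped.
    """
    resized = []
    for img in nt.unbind():  # each entry is a regular (C, H, W) tensor
        out = F.interpolate(
            img.unsqueeze(0),  # interpolate() expects (N, C, H, W)
            scale_factor=0.5,
            mode="bilinear",
            align_corners=False,
        )
        resized.append(out.squeeze(0))
    # The outputs still differ in size, so they go back into a NestedTensor.
    return torch.nested.nested_tensor(resized)


# Two made-up images of different spatial sizes.
nt = torch.nested.nested_tensor([torch.rand(3, 32, 48), torch.rand(3, 64, 40)])
halved = halve_each(nt)
shapes = [tuple(t.shape) for t in halved.unbind()]
```

The cost is the same as looping over a plain Python list of tensors, which is why the container alone does not bring a speed-up.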

Regarding the UX: I tried to see if our V2 transforms could natively support NestedTensor. It's pretty much the same story as for TensorDict (https://github.com/pytorch/vision/issues/7763): NestedTensors don't integrate too well with pytree, so right now nothing really works. Maybe the NestedTensor devs would be open to integrating it with pytree?

(Thinking about all this made me open https://github.com/pytorch/vision/issues/7774, which is partially related to what you want to do. But I still don't think it will be a silver bullet, sorry :/ )