pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

MultiScaleRoIAlign.infer_scale fails on very oblong images #3747

Open quasimik opened 3 years ago

quasimik commented 3 years ago

🐛 Bug

If the first image passed to a MultiScaleRoIAlign instance is very tall or very wide, MultiScaleRoIAlign.infer_scale raises an AssertionError.

To Reproduce

In test.py:

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.ops import MultiScaleRoIAlign

transform = GeneralizedRCNNTransform(
    min_size=100,
    max_size=100,
    image_mean=(0.5, 0.5, 0.5),
    image_std=(0.1, 0.1, 0.1)
)
backbone = resnet_fpn_backbone(backbone_name='resnet50', pretrained=True)

for width in range(100, 1000, 100):
    print(f'image size: H 100 * W {width}')
    images = [torch.rand((3, 100, width))]
    images, _ = transform(images)  # ImageList
    features = backbone(images.tensors)

    # instantiate MultiScaleRoIAlign here because the code of interest lies
    #   inside MultiScaleRoIAlign.infer_scale, which only runs once per instance
    pool = MultiScaleRoIAlign(
        featmap_names=['0', '1', '2', '3'],
        output_size=(7, 7),
        sampling_ratio=2
    )
    features = pool(features, [torch.tensor([[0., 0., 0., 0.]])], images.image_sizes)

Execute:

$ python test.py
image size: H 100 * W 100
image size: H 100 * W 200
image size: H 100 * W 300
image size: H 100 * W 400
image size: H 100 * W 500
image size: H 100 * W 600
Traceback (most recent call last):
  File "test.py", line 27, in <module>
    features = pool(features, [torch.tensor([[0., 0., 0., 0.]])], images.image_sizes)
  File "/[...]/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/[...]/torchvision/ops/poolers.py", line 231, in forward
    self.setup_scales(x_filtered, image_shapes)
  File "/[...]/torchvision/ops/poolers.py", line 192, in setup_scales
    scales = [self.infer_scale(feat, original_input_shape) for feat in features]
  File "/[...]/torchvision/ops/poolers.py", line 192, in <listcomp>
    scales = [self.infer_scale(feat, original_input_shape) for feat in features]
  File "/[...]/torchvision/ops/poolers.py", line 176, in infer_scale
    assert possible_scales[0] == possible_scales[1]
AssertionError

Expected behavior

Very oblong images should not cause this failure. At the very least, there should be a way to manually set up the pooling scales without relying on heuristics from the initial image.
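As an illustration of what a manual setup could compute, here is a minimal sketch that derives the per-level scales from known strides instead of inferring them from an image. It assumes the feature maps '0' to '3' come from a ResNet-FPN backbone with strides 4, 8, 16 and 32 (check this for your own backbone). The helper `scales_from_strides` is hypothetical and not part of torchvision; how to inject the resulting values into a MultiScaleRoIAlign instance depends on the torchvision version and is not shown.

```python
import math
from typing import List, Sequence, Tuple


def scales_from_strides(strides: Sequence[int]) -> Tuple[List[float], int, int]:
    """Derive per-level pooling scales from known feature-map strides.

    Hypothetical helper: assumes every stride is a power of two and that the
    strides are listed from the finest level to the coarsest.
    """
    scales = []
    for stride in strides:
        k = math.log2(stride)
        if k != int(k):
            raise ValueError(f"stride {stride} is not a power of two")
        scales.append(1.0 / stride)
    # A scale of 2**-k corresponds to pyramid level k.
    lvl_min = int(-math.log2(scales[0]))
    lvl_max = int(-math.log2(scales[-1]))
    return scales, lvl_min, lvl_max


scales, lvl_min, lvl_max = scales_from_strides([4, 8, 16, 32])
print(scales, lvl_min, lvl_max)  # [0.25, 0.125, 0.0625, 0.03125] 2 5
```

Because the scales come from the architecture rather than from the first image, the aspect ratio of the input plays no role.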

Environment

$ python collect_env.py 
Collecting environment information...
PyTorch version: 1.8.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.2.152
GPU models and configuration: GPU 0: GeForce GTX 1650
Nvidia driver version: 460.73.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] efficientnet-pytorch==0.7.0
[pip3] numpy==1.19.2
[pip3] torch==1.8.1
[pip3] torchvision==0.9.1
[conda] Could not collect

Additional context

quasimik commented 3 years ago

An easy workaround is to make sure that the first image passed to the pooler is square-ish, because the scales are inferred only once per instance (on the first forward call) and then reused.
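To see why this works, here is a toy, pure-Python sketch of the "infer once, then cache" behavior. `LazyScales` is a hypothetical stand-in, not torchvision's MultiScaleRoIAlign, and the feature sizes are illustrative values assuming a 128x128 warm-up image and strides 4 to 32.

```python
import math
from typing import List, Optional, Sequence, Tuple

Size = Tuple[int, int]


def _nearest_pow2(x: float) -> float:
    # Round x to the nearest power of two (rounding in log2 space).
    return 2.0 ** round(math.log2(x))


class LazyScales:
    """Hypothetical toy pooler that infers its scales on the first call only."""

    def __init__(self) -> None:
        self.scales: Optional[List[float]] = None

    def __call__(self, feature_sizes: Sequence[Size], image_size: Size) -> List[float]:
        if self.scales is None:  # only the first call runs the heuristic
            scales = []
            for fh, fw in feature_sizes:
                sh = _nearest_pow2(fh / image_size[0])
                sw = _nearest_pow2(fw / image_size[1])
                if sh != sw:
                    raise AssertionError(f"inconsistent scales {sh} vs {sw}")
                scales.append(sh)
            self.scales = scales
        return self.scales


pool = LazyScales()
# Warm-up call with a square 128x128 input: every level gives consistent estimates.
pool([(32, 32), (16, 16), (8, 8), (4, 4)], (128, 128))
# A later oblong input reuses the cached scales instead of re-running the check.
print(pool([(8, 32), (4, 16), (2, 8), (1, 4)], (16, 128)))  # [0.25, 0.125, 0.0625, 0.03125]
```

The caveat is that the cached scales silently apply to every later input, which is why this is a workaround rather than a fix.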

fmassa commented 3 years ago

The underlying problem in the example is that the image, after downsampling by the network, has a size of 1 in one of the dimensions (it would have been 0 if it weren't for padding).
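A minimal sketch of the per-axis estimate behind the failing `assert possible_scales[0] == possible_scales[1]` from the traceback. It assumes (not confirmed by this thread) that each axis ratio is rounded to the nearest power of two. The function name and the sizes are hypothetical, chosen only to show how a padded size of 1 skews one axis.

```python
import math
from typing import List, Tuple


def possible_scales(feature_hw: Tuple[int, int], image_hw: Tuple[int, int]) -> List[float]:
    """Per-axis scale estimate, rounded to the nearest power of two (assumed heuristic)."""
    return [2.0 ** round(math.log2(f / i)) for f, i in zip(feature_hw, image_hw)]


# Hypothetical sizes: a 16x128 image padded to 32x128 and downsampled by 32.
# Without padding the height would be 16 / 32 = 0.5, i.e. 0 after flooring;
# padding turns it into 1, which doubles the apparent height scale.
print(possible_scales((1, 4), (16, 128)))   # [0.0625, 0.03125] -> assertion fails
# A square 128x128 image at the same level gives matching estimates.
print(possible_scales((4, 4), (128, 128)))  # [0.03125, 0.03125]
```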

I think we can extend the check so that a scaled size of 1 is treated as a dummy value and skipped in the comparison.
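One possible reading of that suggestion, as a hypothetical sketch (not torchvision code): axes whose feature size is 1 are dropped before the estimates are compared. It raises `ValueError` instead of using `assert`, so the check still runs under `python -O`. The function name, the rounding rule and the sizes are assumptions for illustration.

```python
import math
from typing import Tuple


def infer_scale_lenient(feature_hw: Tuple[int, int], image_hw: Tuple[int, int]) -> float:
    """Hypothetical scale check that ignores axes with a feature size of 1."""
    # A size of 1 may only exist because of padding, so treat it as a dummy value.
    candidates = [
        2.0 ** round(math.log2(f / i))
        for f, i in zip(feature_hw, image_hw)
        if f > 1
    ]
    if not candidates:
        # Every axis is 1: nothing reliable is left to compare, so fail loudly.
        raise ValueError(f"cannot infer scale from feature size {feature_hw}")
    if any(c != candidates[0] for c in candidates):
        raise ValueError(f"inconsistent scale estimates {candidates}")
    return candidates[0]


print(infer_scale_lenient((1, 4), (16, 128)))  # 0.03125
```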

That said, this is a corner case that is in principle bound not to work as expected anyway (the boxes will almost certainly not be correct), so I'd be tempted to say that raising an error is fine as well: at least it signals to the user that something is wrong, instead of silently working and giving bad results.