pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[Feature Request] PadToSquare: Square Padding to Preserve Aspect Ratios When Resizing Images with Varied Shapes in torchvision.transforms.v2 #8699

Open geezah opened 1 month ago

geezah commented 1 month ago

🚀 The feature

A new transform class, PadToSquare, that makes non-square images square by padding the shorter side. Its configuration is inspired by torchvision.transforms.v2.Pad. Note that no positional size argument is needed, since the target size is computed from the non-square image being padded. This feature would be useful wherever downstream models or processes require square inputs, and it simplifies the pipeline by embedding this transformation within torchvision.transforms.v2.

Case 1 (Width > Height): [image 001463, original and padded]

Case 2 (Height > Width): [image 001087, original and padded]

Case 3: Height == Width: Nothing changes :-)

Image Sources: VOC2012
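The three cases reduce to a small amount of arithmetic. The sketch below mirrors the padding computation in the proposed _get_params using plain Python; square_padding is a hypothetical helper name, not a torchvision function. It computes both paddings at once, because the padding for the longer side is always zero, and it returns them in the [left, top, right, bottom] order that F.pad expects.

```python
from typing import List


def square_padding(height: int, width: int) -> List[int]:
    """Return [left, top, right, bottom] padding that centers an image in a square.

    Hypothetical helper: only the shorter side gets padding, and when the
    difference is odd the extra pixel goes to the bottom or right side.
    """
    target = max(height, width)
    pad_h = target - height  # zero unless the image is wider than it is tall
    pad_w = target - width   # zero unless the image is taller than it is wide
    top, left = pad_h // 2, pad_w // 2
    return [left, top, pad_w - left, pad_h - top]


# Case 1 (Width > Height): pad top and bottom.
print(square_padding(168, 224))  # [0, 28, 0, 28]
# Case 2 (Height > Width): pad left and right.
print(square_padding(224, 168))  # [28, 0, 28, 0]
# Case 3 (Height == Width): nothing changes.
print(square_padding(224, 224))  # [0, 0, 0, 0]
```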

Motivation, pitch

I'm working on a multi-label classification project that requires square images, but the input dataset contains images of many different shapes and aspect ratios. PadSquare would streamline the preprocessing pipeline by automatically padding these images to a square shape while allowing flexible padding modes. This avoids distortion when the images are resized afterwards and simplifies handling of varied image shapes. The request stems from the need to produce square inputs in a straightforward, robust way with consistent padding.

Alternatives

I have considered using the existing padding transforms in torchvision, but they require additional logic to conditionally pad only the shorter side, which makes the code less modular (see, e.g., this discussion). The current alternative is to calculate the padding manually and then apply it to obtain a square shape. A dedicated PadSquare transform would turn this common operation into a reusable, convenient utility.

Additional context

The PadSquare class uses the _get_params method to calculate the necessary padding values so that the original image is centered in the padded one (when the total padding is odd, the extra pixel goes to the bottom or right side). It also supports multiple padding modes and accepts a fill value for 'constant' mode. It would make torchvision.transforms.v2 more versatile by providing a reusable utility for data preprocessing. Let me know what you think of it! :-)
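To make the four padding modes concrete, here is a pure-Python sketch that pads a single row of pixel values. pad_1d is a hypothetical helper for illustration only, not a torchvision function. It reproduces the [1, 2, 3, 4] examples given in the docstring of the initial implementation, and it assumes the padding on each side is shorter than the row for the mirrored modes.

```python
from typing import List, Sequence


def pad_1d(values: Sequence[int], before: int, after: int,
           mode: str = "constant", fill: int = 0) -> List[int]:
    """Pad one row the way each padding mode treats an image edge (hypothetical helper)."""
    v = list(values)
    if mode == "constant":
        # Use the fill value on both sides.
        left, right = [fill] * before, [fill] * after
    elif mode == "edge":
        # Repeat the outermost value.
        left, right = [v[0]] * before, [v[-1]] * after
    elif mode == "reflect":
        # Mirror around the edge value without repeating it.
        left = v[1:before + 1][::-1]
        right = v[-after - 1:-1][::-1]
    elif mode == "symmetric":
        # Mirror including the edge value; guard after == 0 because v[-0:] is the whole row.
        left = v[:before][::-1]
        right = v[-after:][::-1] if after else []
    else:
        raise ValueError(f"unsupported padding mode: {mode}")
    return left + v + right


row = [1, 2, 3, 4]
print(pad_1d(row, 2, 2, "constant"))   # [0, 0, 1, 2, 3, 4, 0, 0]
print(pad_1d(row, 2, 2, "edge"))       # [1, 1, 1, 2, 3, 4, 4, 4]
print(pad_1d(row, 2, 2, "reflect"))    # [3, 2, 1, 2, 3, 4, 3, 2]
print(pad_1d(row, 2, 2, "symmetric"))  # [2, 1, 1, 2, 3, 4, 4, 3]
```

In PadSquare, the same rules are applied along whichever image dimension is shorter.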

Initial Implementation

My initial implementation of PadSquare is inspired by the implementation of Pad.

from typing import Any, Dict, List, Literal, Union, Type

import torchvision.transforms.v2.functional as F
from torchvision.transforms import v2
from torchvision.transforms.v2._utils import (
    _check_padding_mode_arg,
    _get_fill,
    _setup_fill_arg,
    _FillType,
)

class PadSquare(v2.Transform):
    """Pad a non-square input to make it square by padding the shorter side to match the longer side.
    Args:
        fill (number or tuple or dict, optional): Pixel fill value used when ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
        padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is "constant".

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Example:
        >>> import torch
        >>> from torchvision.transforms.v2 import PadSquare
        >>> rectangular_image = torch.randint(0, 255, (3, 224, 168), dtype=torch.uint8)
        >>> transform = PadSquare(padding_mode='constant', fill=0)
        >>> square_image = transform(rectangular_image)
        >>> print(square_image.size())
        torch.Size([3, 224, 224])
    """

    def __init__(
        self,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ):
        super().__init__()

        # Raises ValueError unless padding_mode is one of
        # "constant", "edge", "reflect" or "symmetric".
        _check_padding_mode_arg(padding_mode)

        self.padding_mode = padding_mode
        self.fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Get the original height and width shared by all inputs in this call.
        # query_size raises an error if the inputs have different sizes.
        orig_height, orig_width = v2.query_size(flat_inputs)

        # Find the target size (maximum of height and width)
        target_size = max(orig_height, orig_width)

        if orig_height < target_size:
            # Need to pad height
            pad_height = target_size - orig_height
            pad_top = pad_height // 2
            pad_bottom = pad_height - pad_top
            pad_left = 0
            pad_right = 0
        else:
            # Need to pad width
            pad_width = target_size - orig_width
            pad_left = pad_width // 2
            pad_right = pad_width - pad_left
            pad_top = 0
            pad_bottom = 0

        # The padding needs to be in the format [left, top, right, bottom]
        return dict(padding=[pad_left, pad_top, pad_right, pad_bottom])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self.fill, type(inpt))
        return self._call_kernel(
            F.pad,
            inpt,
            padding=params["padding"],
            padding_mode=self.padding_mode,
            fill=fill
        )

NicolasHug commented 1 month ago

Thanks a lot for the super detailed feature request @geezah !

This sounds reasonable but before we move towards a PR, can you help me understand why you think padding is preferable to resizing the input image here?

Side note: using query_size(flat_inputs) as suggested in the snippet above will enforce that all images in the input are of the same original shape. I don't think we can avoid such enforcement (at least not easily), but I just wanted to point that out in case that's not desirable for your own use-case.
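To make the constraint concrete: within a single call, every image, mask or box must share one height and width, because a single padding is computed for all of them. Separate calls (for example, one per dataset sample) can still see different sizes. The pure-Python sketch below only models that check; common_size is a hypothetical helper, and the exact exception types and messages of torchvision's query_size may differ.

```python
from typing import Iterable, Tuple


def common_size(sizes: Iterable[Tuple[int, int]]) -> Tuple[int, int]:
    """Return the (height, width) shared by every input of one transform call.

    Hypothetical, simplified model of the check v2.query_size performs;
    not torchvision's implementation.
    """
    unique = set(sizes)
    if not unique:
        raise TypeError("no image-like input found")
    if len(unique) > 1:
        # One padding is computed per call, so mismatched sizes are rejected.
        raise ValueError(f"found multiple HxW sizes in one call: {sorted(unique)}")
    return next(iter(unique))


# One call on an image and its mask of the same size works.
print(common_size([(224, 168), (224, 168)]))  # (224, 168)

# A different sample may have a different size, as long as the call is consistent.
print(common_size([(300, 500), (300, 500)]))  # (300, 500)

# Passing differently sized images together in one call fails.
try:
    common_size([(224, 168), (168, 224)])
except ValueError as err:
    print("ValueError:", err)
```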

geezah commented 1 month ago

Thanks for the feedback! The padding approach was suggested mainly for cases where preserving aspect ratios could be beneficial, such as:

You're right about the requirement that all inputs share the same size. To handle variable input sizes, one could implement a custom collate_fn that performs random resizing at batch creation time instead of in the transform pipeline. This would allow more flexibility while keeping batching efficient.

NicolasHug commented 1 month ago

Thanks for getting back to me @geezah. This sounds good; please feel free to submit a PR! Let's start with the simple approach of using query_size; we can consider the collate_fn approach later if needed.

geezah commented 1 month ago

Alright 😄 Thank you for getting back to me so quickly!