pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[feature request] transforms for object detection #3286

Open ydcjeff opened 3 years ago

ydcjeff commented 3 years ago

🚀 Feature

I would like to start adding and supporting transforms (both functional and class-based) for object detection. I know I can take some of them from the references folder, but it would be nice to have them out of the box (OOTB). Here are a few basic transforms I would like to add first:

Pitch

All of the above transforms will accept two arguments (image and target) when called. This breaks compatibility with Compose and nn.Sequential, which pass a single input through each transform, but aren't we already writing a custom Compose or nn.Sequential for detection? So I think it's OK to start introducing the necessary two-argument transforms for detection, segmentation, etc., and let users write a custom Compose or nn.Sequential that calls the transforms the way they like.
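The two-argument calling convention only needs a small Compose-style wrapper that threads the (image, target) pair through each step, similar in spirit to the `Compose` in torchvision's detection references. A minimal sketch under that assumption; `PairCompose` and the toy transforms in the usage lines are hypothetical names, not torchvision API.

```python
from typing import Any, Callable, Sequence, Tuple

# A pair transform is any callable mapping (image, target) -> (image, target).
PairTransform = Callable[[Any, Any], Tuple[Any, Any]]


class PairCompose:
    """Hypothetical Compose variant that passes an (image, target) pair along."""

    def __init__(self, transforms: Sequence[PairTransform]):
        self.transforms = list(transforms)

    def __call__(self, image: Any, target: Any) -> Tuple[Any, Any]:
        # Each step receives the outputs of the previous step.
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self) -> str:
        body = "".join(f"\n    {t}" for t in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"


# Usage with toy transforms on plain Python values (hypothetical).
def double_image(image, target):
    return image * 2, target


def add_label(image, target):
    return image, {**target, "label": 1}


pipeline = PairCompose([double_image, add_label])
img, tgt = pipeline(3, {"boxes": []})
```

Because the wrapper only relies on the (image, target) -> (image, target) contract, the same pipeline object can hold flips, letterboxing, or any user-defined step.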

Additional context

Current code:

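```python
import random
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as FT
from PIL import Image
from torch import Tensor


class RandomHorizontalFlipWithBBox(nn.Module):
    """Randomly flip a PIL image and its (xmin, ymin, xmax, ymax) boxes horizontally."""

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, img, target):
        if random.random() < self.prob:
            width = img.width  # assumes a PIL image
            # Mirror the x-coordinates: new xmin = W - old xmax, new xmax = W - old xmin.
            # The right-hand side is a copy (advanced indexing), so the swap is safe.
            # Note: `target` is modified in place.
            target[..., [0, 2]] = width - target[..., [2, 0]]
            return FT.hflip(img), target
        return img, target

    def __repr__(self):
        return self.__class__.__name__ + "(p={})".format(self.prob)
```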

```python
class RandomVerticalFlipWithBBox(nn.Module):
    """Randomly flip a PIL image and its (xmin, ymin, xmax, ymax) boxes vertically."""

    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, img, target):
        if random.random() < self.prob:
            height = img.height  # assumes a PIL image
            # Mirror the y-coordinates: new ymin = H - old ymax, new ymax = H - old ymin.
            # Note: `target` is modified in place.
            target[..., [1, 3]] = height - target[..., [3, 1]]
            return FT.vflip(img), target
        return img, target

    def __repr__(self):
        return self.__class__.__name__ + "(p={})".format(self.prob)
```

```python
class LetterBox(nn.Module):
    """
    Apply a letterbox transform to an image and its bounding-box target:
    resize while keeping the aspect ratio, then pad to the requested size.

    Args:
        size (int or tuple of int): the size of the transformed image, as (width, height).
    """

    def __init__(self, size: Union[int, Tuple[int, int]]):
        super().__init__()
        self.size = size
        if isinstance(size, int):
            self.size = (size, size)

    def forward(self, img: Image.Image, target: Union[np.ndarray, Tensor]):
        """
        Args:
            img (PIL Image): Image to be transformed.
            target (np.ndarray or Tensor): bounding boxes in (xmin, ymin, xmax, ymax)
                format; modified in place.

        Returns:
            tuple: (image, target)
        """
        old_width, old_height = img.size
        width, height = self.size

        # Scale so the image fits inside (width, height) without distortion.
        ratio = min(width / old_width, height / old_height)
        new_width = int(old_width * ratio)
        new_height = int(old_height * ratio)
        img = T.functional.resize(img, (new_height, new_width))  # resize expects (h, w)

        # Split the padding between both sides; an odd remainder goes to left/top.
        pad_x = (width - new_width) * 0.5
        pad_y = (height - new_height) * 0.5
        left, right = round(pad_x + 0.1), round(pad_x - 0.1)
        top, bottom = round(pad_y + 0.1), round(pad_y - 0.1)
        padding = (left, top, right, bottom)
        img = T.functional.pad(img, padding, 255 // 2)  # gray fill (127)

        # Scale the boxes, then shift them by the left/top padding. Both x-coordinates
        # move by `left` and both y-coordinates by `top`; the right/bottom padding
        # does not change box positions.
        if isinstance(target, torch.Tensor):
            target[..., 0] = torch.round(ratio * target[..., 0]) + left
            target[..., 1] = torch.round(ratio * target[..., 1]) + top
            target[..., 2] = torch.round(ratio * target[..., 2]) + left
            target[..., 3] = torch.round(ratio * target[..., 3]) + top
        elif isinstance(target, np.ndarray):
            target[..., 0] = np.rint(ratio * target[..., 0]) + left
            target[..., 1] = np.rint(ratio * target[..., 1]) + top
            target[..., 2] = np.rint(ratio * target[..., 2]) + left
            target[..., 3] = np.rint(ratio * target[..., 3]) + top
        return img, target

    def __repr__(self):
        return self.__class__.__name__ + f"({self.size})"
```
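The letterbox box arithmetic is easy to get wrong, because both corners of a box move by the same left/top offset regardless of how much padding ends up on the right or bottom. The pure-Python sketch below, with hypothetical helper names `letterbox_params` and `letterbox_box`, reproduces the padding split used in LetterBox and maps a single box so the result can be checked by hand.

```python
from typing import List, Sequence, Tuple


def letterbox_params(
    old_size: Tuple[int, int], new_size: Tuple[int, int]
) -> Tuple[float, Tuple[int, int, int, int]]:
    """Return (ratio, (left, top, right, bottom)); sizes are (width, height)."""
    old_w, old_h = old_size
    new_w, new_h = new_size
    ratio = min(new_w / old_w, new_h / old_h)
    scaled_w, scaled_h = int(old_w * ratio), int(old_h * ratio)
    pad_x = (new_w - scaled_w) * 0.5
    pad_y = (new_h - scaled_h) * 0.5
    # Same split as LetterBox: an odd remainder goes to the left/top side.
    left, right = round(pad_x + 0.1), round(pad_x - 0.1)
    top, bottom = round(pad_y + 0.1), round(pad_y - 0.1)
    return ratio, (left, top, right, bottom)


def letterbox_box(box: Sequence[float], ratio: float, left: int, top: int) -> List[int]:
    """Scale an (xmin, ymin, xmax, ymax) box, then shift it by the left/top padding."""
    xmin, ymin, xmax, ymax = box
    return [
        round(ratio * xmin) + left,
        round(ratio * ymin) + top,
        round(ratio * xmax) + left,  # shifted by `left`, not `right`
        round(ratio * ymax) + top,   # shifted by `top`, not `bottom`
    ]


ratio, (left, top, right, bottom) = letterbox_params((402, 800), (400, 400))
box = letterbox_box([100, 40, 200, 160], ratio, left, top)
```

With a 402x800 image letterboxed to 400x400, the scale is 0.5 and the 199 px of horizontal padding splits into 100 on the left and 99 on the right. The box [100, 40, 200, 160] maps to [150, 20, 200, 80] and keeps its 50 px scaled width; adding `right` to xmax instead would give 199 and shrink the box by one pixel.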

Thank you!

cc @vfdev-5, @fmassa

fmassa commented 3 years ago

Hi,

Thanks for opening this issue.

This has been on our radar for a while, but we never managed to find the right balance between simplicity and generality in the API. For example, the API you proposed wouldn't be enough if we wanted to handle image + boxes + keypoints, or even image + segmentation map, so we would need a number of repeated implementations to cover the models in torchvision.
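One way to avoid a separate implementation per target type is to key the target by name, as torchvision's detection models already do with `boxes`, `masks`, and `keypoints` entries, and let a single transform update whichever entries are present. The sketch below uses plain nested lists so the arithmetic is visible; `hflip_target` is a hypothetical helper, and it deliberately ignores the left/right keypoint swap (e.g. left eye and right eye) that a real horizontal flip of person keypoints also needs.

```python
from typing import Any, Dict


def hflip_target(target: Dict[str, Any], width: int) -> Dict[str, Any]:
    """Horizontally flip every geometric entry of a detection target.

    Hypothetical helper: boxes are [xmin, ymin, xmax, ymax], masks are
    per-instance H x W nested lists, keypoints are per-instance [x, y, visibility].
    Entries it does not recognise (e.g. labels) are passed through unchanged.
    """
    out = dict(target)  # shallow copy; the input target is not modified
    if "boxes" in target:
        # New xmin = W - old xmax, new xmax = W - old xmin.
        out["boxes"] = [[width - x2, y1, width - x1, y2] for x1, y1, x2, y2 in target["boxes"]]
    if "masks" in target:
        # Reverse each row of each instance mask.
        out["masks"] = [[row[::-1] for row in mask] for mask in target["masks"]]
    if "keypoints" in target:
        # Mirror x only; semantic left/right swapping is NOT handled here.
        out["keypoints"] = [
            [[width - x, y, v] for x, y, v in instance] for instance in target["keypoints"]
        ]
    return out


target = {
    "boxes": [[0, 1, 1, 2]],
    "labels": [3],
    "masks": [[[1, 1, 0, 0], [0, 0, 0, 1]]],
    "keypoints": [[[1, 0, 2]]],
}
flipped = hflip_target(target, width=4)
```

Adding a new target type then means adding one branch to each geometric transform rather than a new transform class per combination of inputs.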

For an earlier attempt at this API, see https://github.com/pytorch/vision/issues/1406 and the discussion there.

I would love to hear your thoughts on this.

zhiqwang commented 3 years ago

FYI, it seems that the existing batch_images in GeneralizedRCNNTransform plays the same role as the proposed LetterBox here.

https://github.com/pytorch/vision/blob/51500c7e067f2f92765734c69e2b082c221a2eae/torchvision/models/detection/transform.py#L199-L217
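To my reading of torchvision's GeneralizedRCNNTransform, the two overlap only partly: resizing happens in a separate step (driven by `min_size`/`max_size`), and `batch_images` then copies each image into the top-left corner of a zero-filled batch tensor whose height and width are the batch maximum rounded up to a multiple of `size_divisible` (32 by default). Because that padding only touches the bottom and right edges, box coordinates need no offset, unlike the centered padding in LetterBox. The sketch below reproduces only that size computation; `padded_batch_size` is a hypothetical name.

```python
import math
from typing import Iterable, Tuple


def padded_batch_size(
    image_sizes: Iterable[Tuple[int, int]], size_divisible: int = 32
) -> Tuple[int, int]:
    """Return the (height, width) every image in the batch is padded to.

    Images are assumed to be placed at the top-left corner, so boxes keep their
    coordinates and only the bottom/right edges receive padding.
    """
    sizes = list(image_sizes)
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    # Round each dimension up to the next multiple of size_divisible.
    padded_h = int(math.ceil(max_h / size_divisible) * size_divisible)
    padded_w = int(math.ceil(max_w / size_divisible) * size_divisible)
    return padded_h, padded_w


print(padded_batch_size([(480, 640), (500, 333)]))
```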