pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Keypoint transform #1131

Open gheaeckkseqrz opened 5 years ago

gheaeckkseqrz commented 5 years ago

Hi Pytorch community 😄

I started working on keypoint transformation (as was requested in #523). I worked on it in the context of data augmentation for object detection tasks.

I submitted a proposal in PR #1118, but as @fmassa pointed out, that's not something we can merge without reviewing the design choices.

I've implemented the functionality by changing the signature of the transform's `__call__()` method from `def __call__(self, img):` to `def run(self, img, keypoints):`, so that every transform can work on a list of keypoints in addition to the image itself.

I've gone with keypoints because a point is the most basic element: bounding boxes are defined by points, segmentation masks can be defined by points, facial landmarks are keypoints... If we have the ability to transform a point, we have the ability to transform anything.

My goal with that design was to make the data augmentation as straightforward as possible. I added a wrapper class to transform the XML annotation from VOCDetection into a keypoint list and feed it to the transform pipeline.

class TransformWrapper(object):
    """Adapts an (img, keypoints) transform pipeline to VOCDetection's
    (img, annotation dict) interface."""

    def __init__(self, transforms):
        super(TransformWrapper, self).__init__()
        self.transforms = transforms

    def __call__(self, img, anno):
        # Flatten every bounding box of the annotation into the keypoint list.
        keypoints = []
        objs = anno['annotation']['object']
        if not isinstance(objs, list):
            objs = [objs]
        for o in objs:
            b = o['bndbox']
            x1 = int(b['xmin'])
            x2 = int(b['xmax'])
            y1 = int(b['ymin'])
            y2 = int(b['ymax'])
            keypoints.append([x1, x2])
            keypoints.append([y1, y2])

        # Run the image and the keypoints through the transform pipeline.
        img, keypoints = self.transforms(img, keypoints)

        # Write the transformed coordinates back into the annotation dict.
        for o in objs:
            b = o['bndbox']
            x = keypoints.pop(0)
            b['xmin'] = str(int(x[0]))
            b['xmax'] = str(int(x[1]))
            y = keypoints.pop(0)
            b['ymin'] = str(int(y[0]))
            b['ymax'] = str(int(y[1]))
        return img, anno

This allows for usage as simple as:

transform = transformWrapper.TransformWrapper(torchvision.transforms.Compose([
    torchvision.transforms.Resize(600),
    torchvision.transforms.ToTensor()]))
vocloader = torchvision.datasets.voc.VOCDetection("/home/wilmot_p/DATA/", transforms=transform)

And the annotations come out with values corresponding to the resized image.

The aim of this thread is to bring up other use cases of keypoint transformation that I may not have thought of and that may be incompatible with this design, so that we can make a sensible design decision that works for everyone. So if you have an opinion on this matter, please share 😄

Currently, one of the drawbacks of my design is that it breaks the interface for Lambda: it used to take only the image as an input parameter, it now takes the image and the keypoint list, and that breaks backwards compatibility.
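To illustrate the break (the two-argument form below is an assumption based on the description above, not the actual PR code):

import torchvision.transforms as T

# Current API: a user-supplied Lambda only ever sees the image.
old_style = T.Lambda(lambda img: img)

# Under the proposed two-argument interface, the callable would also have to
# accept and return the keypoint list, so existing one-argument lambdas break.
new_style = T.Lambda(lambda img, keypoints: (img, keypoints))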

fmassa commented 5 years ago

Thanks for opening this issue and the proposal!

I have a question about your proposal: what if we want to have bounding boxes and keypoints as the target for our model, for example in Keypoint R-CNN as in https://github.com/pytorch/vision/blob/bbd363ca2713fb68e1e190206578e600a87baf90/torchvision/models/detection/keypoint_rcnn.py#L20-L100

I believe we would need to extend the function call to take another argument.

And if we want to also have masks at the same time, we would be in the business of adding yet another argument. Or, we could find a way to make this support a (potentially) arbitrary number of elements to transform.

The problem with those generic approaches is that we sometimes do not want to apply all transforms to all data types. For example, we only want to add color augmentation to the image, not to the segmentation map.

This has been extensively discussed in the past, see https://github.com/pytorch/vision/issues/9 https://github.com/pytorch/vision/issues/230 https://github.com/pytorch/vision/issues/533 (and the issues linked therein), and the conclusion for now has been that all the alternatives that have been proposed are either too complicated, or not general enough. For that reason, the current recommended way of handling composite transforms is to use the functional interface, which is very generic and gives you full control, at the expense of a bit more code. We have also recently improved the documentation in https://github.com/pytorch/vision/pull/602 in order to make this a bit more clear.
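For reference, a minimal sketch of that functional-interface pattern (sample the random parameters once, then apply them to every piece of data), assuming PIL images and boxes given as [xmin, ymin, xmax, ymax] lists; the JointRandomHorizontalFlip class is illustrative, not a torchvision API:

import random
import torchvision.transforms.functional as F

class JointRandomHorizontalFlip(object):
    """Sample the random decision once, then apply it to both the image
    (via the functional interface) and the boxes."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, boxes):
        if random.random() < self.p:
            w = img.width
            img = F.hflip(img)
            # Mirror each box horizontally around the image width.
            boxes = [[w - xmax, ymin, w - xmin, ymax]
                     for xmin, ymin, xmax, ymax in boxes]
        return img, boxes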

Thoughts?


gheaeckkseqrz commented 5 years ago

I have a question about your proposal: what if we want to have bounding boxes and keypoints as the target for our model, for example in Keypoint R-CNN as in

This shouldn't be a problem: bounding boxes are usually encoded as keypoints (top-left & bottom-right). You can pass those keypoints to the transforms in order to get the corresponding corners for the boxes. This is easily done using a wrapper like the one in the first message.

As I see it, the role of a transform is really just to map a point to another point (single responsibility principle); it's not supposed to be a full bridge between the dataset object and the model.

At the end of the day, everything is a point: bounding boxes are defined by their corners, which are points; images are really just a grid of WxH points; and masks are either an image or a set of keypoints. The only things that are not points are labels, but they don't get transformed.
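As an illustration of the "boxes are just points" argument, a hedged sketch of hypothetical helpers (not part of the proposal's code) that encode a box as its corner points and rebuild an axis-aligned box after an arbitrary point-wise transform such as a rotation:

def box_to_corners(box):
    # Represent a box by its four corners so any point-wise transform
    # (flip, rotation, perspective) can be applied to them directly.
    xmin, ymin, xmax, ymax = box
    return [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]

def corners_to_box(points):
    # After transforming the corners, take the axis-aligned bounding box again.
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return (min(xs), min(ys), max(xs), max(ys))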

The problem right now is that if you use a dataset like VOCDetection and try to augment your data with RandomRotation or RandomPerspective, you get a rotated/squished image, but there is currently no way to transform the bounding boxes accordingly, as the transformation parameters are lost once you exit the transform function.

I think the design I presented needs to be updated to take a list of images and a list of keypoints, rather than just an image and a list of keypoints, as sometimes what you want is to apply the same transform to multiple images (image & segmentation mask pairs).

ColorJitter is a bit of an edge case, as it doesn't remap pixel positions but rather modifies their values. Not sure how to solve this one.

I was thinking about adding a type check: since images are encoded as float tensors and masks should usually be int tensors, we could use that information to decide whether or not to apply the transform, but that would break in the case where we get a pair like (scene, depth/heat map).
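A minimal sketch of that heuristic, assuming tensor inputs (the helper name is illustrative):

import torch

def is_photometric_target(t):
    # Heuristic described above: apply photometric transforms (e.g. ColorJitter)
    # only to floating-point "image-like" tensors; skip integer label/mask tensors.
    # Note: this misclassifies float-valued depth/heat maps, as pointed out above.
    return torch.is_floating_point(t)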

Another solution would be to add a third parameter, so that each transform would receive ([images, ], [mask_images, ], [keypoints, ]), but that means adding a third parameter to the interface just to work around the special case of ColorJitter.

fmassa commented 5 years ago

This shouldn't be a problem: bounding boxes are usually encoded as keypoints (top-left & bottom-right). You can pass those keypoints to the transforms in order to get the corresponding corners for the boxes. This is easily done using a wrapper like the one in the first message.

This means that you need to wrap your box + keypoint in another data structure, and perform the unwrapping inside the transform wrapper. This requires almost as much code as the current functional interface I believe.

The problem with the current set of transformations is that you always flip the image as well, even if all you want is to flip the keypoints (which should be much cheaper than flipping the image). Ideally, those methods should be decoupled, so that one can perform transformations on those data structures alone. This means that the keypoints need to know the width / height of the image.

One of the solutions that I have proposed in the past was to have some boxing abstractions, like torchvision.Image, torchvision.Keypoint, torchvision.Mask etc., and each one of those abstractions has everything it needs internally to transform itself. This might be one way of handling those different edge cases, where the torchvision.Mask implementation of ColorJitter returns the identity, for example.

And then, we can have a composite class torchvision.TargetCollection, which is a group of any of the aforementioned objects, and calling for example target_collection.resize((300, 300)) propagates the resize to all its constituent elements (which can be images, keypoints, boxes, masks, etc).
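A hedged sketch of what such abstractions might look like (the class and method names mirror the proposal but are otherwise an assumption, not existing torchvision API):

class Keypoints(object):
    # Carries the size of the image it lives in, so it can transform itself
    # without touching the image.
    def __init__(self, points, size):
        self.points = points  # list of (x, y) tuples
        self.size = size      # (width, height)

    def resize(self, new_size):
        ow, oh = self.size
        nw, nh = new_size
        pts = [(x * nw / ow, y * nh / oh) for x, y in self.points]
        return Keypoints(pts, new_size)

    def hflip(self):
        w, _ = self.size
        return Keypoints([(w - x, y) for x, y in self.points], self.size)

    def color_jitter(self, *args):
        # Photometric transforms are the identity for geometric targets.
        return self


class TargetCollection(object):
    # Propagates an operation to every element it contains.
    def __init__(self, **targets):
        self.targets = targets

    def resize(self, new_size):
        return TargetCollection(**{k: v.resize(new_size)
                                   for k, v in self.targets.items()})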

Thoughts?

gheaeckkseqrz commented 5 years ago

This means that you need to wrap your box + keypoint in another data structure, and perform the unwrapping inside the transform wrapper. This requires almost as much code as the current functional interface I believe.

That's right, the idea was to ship the wrapper along with the dataloader, so that for the end user, it just results in a couple of lines.

I'm currently trying to re-implement the YOLO paper for learning purposes, and this is what my data loading / data augmentation setup looks like:

transform = transformWrapper.TransformWrapper(torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.RandomPerspective(distortion_scale=.1, p=1),
    torchvision.transforms.RandomCrop(448),
    torchvision.transforms.ColorJitter(.2, .2, .2, .2),
    torchvision.transforms.ToTensor()]))
vocloader = torchvision.datasets.voc.VOCDetection("/home/wilmot_p/DATA/", transforms=transform)

The problem with the current set of transformations is that you always flip the image as well, even if all you want is to flip the keypoints (which should be much cheaper than flipping the image). Ideally, those methods should be decoupled, so that one can perform transformations on those data structures alone. This means that the keypoints need to know the width / height of the image.

I strongly agree that the keypoint and image transforms should be decoupled. But that means we need a system to share the random parameters. I've seen discussion about re-seeding the RNG before every transform, and even though that would technically work, it feels like bad software design. If we decide to go the decoupled way, the biggest problem we have to solve is RNG synchronisation.
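One way to share the random parameters without re-seeding is to sample them once with the existing static get_params() helpers and then apply them through the functional interface; the keypoint rotation helper below is hypothetical:

import torchvision.transforms as T
import torchvision.transforms.functional as F

def joint_random_rotation(img, keypoints, degrees=(-10, 10)):
    # Sample the random parameters once with the transform's static helper...
    angle = T.RandomRotation.get_params(degrees)
    # ...then apply the same value to the image and to the keypoints.
    img = F.rotate(img, angle)
    # keypoints = rotate_keypoints(keypoints, angle)  # hypothetical helper
    return img, keypoints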

One of the solutions that I have proposed in the past was to have some boxing abstractions, like torchvision.Image, torchvision.Keypoint, torchvision.Mask etc., and each one of those abstractions has everything it needs internally to transform itself. This might be one way of handling those different edge cases, where the torchvision.Mask implementation of ColorJitter returns the identity, for example.

And then, we can have a composite class torchvision.TargetCollection, which is a group of any of the aforementioned objects, and calling for example target_collection.resize((300, 300)) propagates the resize to all its constituent elements (which can be images, keypoints, boxes, masks, etc).

This looks like it should work, and I like the idea of introducing proper types; it just means a lot more code to write. I'll try to come up with a proof of concept over the weekend, to see how that compares to my earlier proposal in terms of ease of use for the end user 😄

qinjian623 commented 3 years ago

Hi, any updates here?

@gheaeckkseqrz

Hi, would it be possible to provide the point transformations as an individual repo, so that I can import it as a single library? That would be a fast path for users.