pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[FEEDBACK] TransformsV2: What may change in the future (we need your input!) #7319

Closed NicolasHug closed 9 months ago

NicolasHug commented 1 year ago

The goal of this issue is two-fold:

We'll detail each of those topics below. Please share any feedback or suggestions you may have to help us provide the most useful APIs!

Subclass (un)wrapping

All tensor operations on a datapoint currently lose the datapoint class and return a pure tensor instead. We call this mechanism "subclass unwrapping":

import torch
from torchvision import datapoints  # beta API in torchvision 0.15 (later renamed tv_tensors)

img1 = datapoints.Image(torch.rand(3, 224, 224))
img2 = datapoints.Image(torch.rand(3, 224, 224))
img3 = img1 + img2  # img3 is not an Image tensor anymore, it's a pure tensor!
assert isinstance(img3, torch.Tensor) and not isinstance(img3, datapoints.Image)

# The same is true for datapoints.Video, datapoints.BoundingBox, datapoints.Mask, etc.

We currently unwrap datapoints for two reasons. First, some of them (e.g. bounding boxes) carry extra meta-data, such as the bbox format, and there is currently no protocol to propagate that meta-data to the result of an operation. Second, in some cases it's impossible to know whether the result of an operation is still a valid datapoint: in the example above, can we still consider img3 to be a valid Image?

We acknowledge that this unwrapping behaviour may seem surprising and unexpected in some cases. E.g. for datapoints that don't have meta-data (so the first reason doesn't apply), one could argue that it's up to the user to decide whether the datapoint is still valid; following that argument, we could always return Images or Videos, since they don't (currently) have any meta-data. We could also think of a way to "force subclass wrapping", e.g. through a context manager like

with force_subclass_wrapping():  # proposed API, it does not exist yet
    img3 = img1 + img2
    assert isinstance(img3, torch.Tensor) and isinstance(img3, datapoints.Image)

Let us know what you think!
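The proposed force_subclass_wrapping() could be built on PyTorch's `__torch_function__` protocol. The sketch below is only an illustration, not torchvision's implementation: `MyImage`, `_force_wrapping` and `force_subclass_wrapping` are hypothetical names, and it assumes PyTorch 2.x, where the private `torch._C.DisableTorchFunctionSubclass` context manager is available.

```python
import contextlib

import torch

_force_wrapping = False  # hypothetical module-level switch


@contextlib.contextmanager
def force_subclass_wrapping():
    """Temporarily make operations on MyImage return MyImage instead of a pure tensor."""
    global _force_wrapping
    previous = _force_wrapping
    _force_wrapping = True
    try:
        yield
    finally:
        _force_wrapping = previous


class MyImage(torch.Tensor):
    """Hypothetical stand-in for an image datapoint that unwraps by default."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Run the op with subclass dispatch disabled: tensor results come back as pure tensors.
        with torch._C.DisableTorchFunctionSubclass():
            out = func(*args, **(kwargs or {}))
        if _force_wrapping and type(out) is torch.Tensor:
            return out.as_subclass(cls)  # re-wrap only when the user asked for it
        return out
```

A module-level flag is the simplest design but is not thread-safe; a real implementation might prefer thread-local state.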

Bounding box clamping

Currently, all transforms that may operate on a bounding box automatically clamp it to the corresponding image dimensions. Whether transforms should clamp by default is up for discussion. We could also let users choose, e.g. by adding a new parameter to all of those transforms.
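To make the "let users choose" option concrete, here is a hypothetical translation kernel for XYXY boxes with a `clamp` flag. `translate_boxes` and its `clamp` parameter are illustrative assumptions; they are not part of torchvision.

```python
import torch


def translate_boxes(boxes, dx, dy, canvas_size, clamp=True):
    """Shift XYXY boxes of shape [N, 4] by (dx, dy) pixels.

    canvas_size is (height, width). `clamp` is a hypothetical opt-out flag.
    """
    height, width = canvas_size
    out = boxes + torch.tensor([dx, dy, dx, dy], dtype=boxes.dtype)  # new tensor, input untouched
    if clamp:
        out[:, 0::2] = out[:, 0::2].clamp(0, width)   # x1, x2
        out[:, 1::2] = out[:, 1::2].clamp(0, height)  # y1, y2
    return out
```

With `clamp=False`, a box pushed partly out of frame keeps its negative coordinates, which preserves how far it extends beyond the image.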

Enforce a single BoundingBox instance in all transforms?

Right now, some transforms allow for multiple BoundingBox instances to be present in the input samples, while others will raise an error. We may consider enforcing one unique BoundingBox instance for all transforms in the future.

(Note that a single BoundingBox instance may still contain multiple bounding boxes!)

How to handle labels?

Sometimes, bounding boxes become degenerate after a transformation and we need a way to remove them, along with their associated labels.

Labels are tricky because they can refer to different things: an image, a bounding box, or a mask. In a previous design we had a special Label datapoint subclass, but we decided not to release it for now because of the ambiguity of what a label should refer to: if we have a sample like img, bbox, label, how do we know whether the label is for the image or for the bounding box?

For this reason we have decided, for now, not to have a Label datapoint class, and instead to let labels be pure tensors (or ints) that pass through all of the transforms. The only transform that can handle labels is SanitizeBoundingBox, which asks users to manually specify which entries in the input correspond to the labels: there's no need to guess, and no ambiguity.

We're still considering changing this and potentially bringing back a Label subclass (this is related to another point in this issue about pairwise transforms, which may need a Label subclass).

We're also considering the alternative of not having a Label subclass, and instead making the label a piece of meta-data attached to the datapoints: e.g. the Image class could have a label meta-data, and so could the BoundingBox class.
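The reason a box-removing transform must know where the labels are: whatever boolean mask removes a box must also remove its label. A minimal sketch with plain tensors, where `drop_degenerate_boxes` is a hypothetical helper and SanitizeBoundingBox may apply additional checks beyond a minimum size:

```python
import torch


def drop_degenerate_boxes(boxes, labels, min_size=1.0):
    """Remove XYXY boxes narrower or shorter than min_size, together with their labels."""
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    keep = (widths >= min_size) & (heights >= min_size)
    # The same mask indexes both tensors, so boxes and labels stay aligned.
    return boxes[keep], labels[keep]
```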

Your input on the subject would be valuable.

How to smoothly support "pairwise" transforms?

There are a few critically useful transforms that operate on pairs of samples instead of a single sample: CutMix, MixUp, etc. Because of their fundamentally different behaviour, they tend to be (and currently are) implemented as collation functions passed to the DataLoader, so they cannot be used like the rest of the transforms, which makes them harder to use.

These transforms also need to transform the labels on top of the input images, and we're still trying to figure out the smoothest way to handle labels (see other point in this issue).

For these reasons, we have left them in the prototype area for now while we polish their APIs.

One option we are discussing (but it is far from finalized) is to implement these transforms as stateful transforms, so that they can be used like regular transforms. Something roughly along those lines:

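class MixUp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._prev_img, self._prev_label = None, None  # sample seen on the previous call

    def forward(self, img, label):
        # _mixup_pair (not shown) mixes the current sample with the previous one
        out = _mixup_pair(img, label, self._prev_img, self._prev_label)
        self._prev_img, self._prev_label = img, label
        return out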

Whether this is a good or a terrible idea is still up for discussion!
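To make the stateful idea concrete, here is one self-contained way the pair mixing could work: the current sample is blended with the previous one using a weight lam drawn from a Beta(alpha, alpha) distribution, as MixUp does. `StatefulMixUp`, `num_classes` and `alpha` are assumptions for this sketch; it is not the prototype implementation.

```python
import torch
import torch.nn.functional as F


class StatefulMixUp(torch.nn.Module):
    """Hypothetical sketch: mixes each sample with the one seen on the previous call."""

    def __init__(self, num_classes: int, alpha: float = 1.0):
        super().__init__()
        self.num_classes = num_classes
        self._beta = torch.distributions.Beta(torch.tensor(alpha), torch.tensor(alpha))
        self._prev_img = None
        self._prev_label = None

    def forward(self, img: torch.Tensor, label: int):
        one_hot = F.one_hot(torch.tensor(label), self.num_classes).float()
        if self._prev_img is None:
            out = (img, one_hot)  # nothing to mix with on the very first call
        else:
            lam = float(self._beta.sample())
            out = (lam * img + (1 - lam) * self._prev_img,
                   lam * one_hot + (1 - lam) * self._prev_label)
        # Remember the unmixed current sample for the next call.
        self._prev_img, self._prev_label = img, one_hot
        return out
```

One consequence of this design: the first sample is returned unmixed, and with a multi-worker DataLoader each worker holds its own copy of the transform, so samples are only ever mixed with samples from the same worker.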

Supporting user-defined datapoints and datapoints methods

Users can already implement custom transforms that are compatible with transforms V2. Implementing a user-defined datapoint is also supported, but we're not too happy with the way we currently enable that support. To enable custom datapoints, we currently expose many of the transforms as methods on the datapoint classes, e.g. the Image class has .resize(), .crop(), .rotate() methods, etc.

This isn't something we're too happy with because it makes the implementation of new transforms cumbersome, and it may also conflict with the Tensor base-class namespace.

We do not guarantee that we'll keep supporting those methods in the future.

Tensor pass-through heuristic

At this time, inputs that aren't datapoints are passed through all transforms unchanged:

transformed_img, other_stuff = t(img, other_stuff)
# other_stuff is passed-through!

Well, not all non-datapoint inputs are passed through: we want the V2 transforms to be fully backward compatible with the V1 transforms, so pure tensors are still treated as Images.

For this reason we currently have implemented a (potentially surprising) heuristic:

If this is confusing, don't worry: 99% of users don't need to care about it.

We're considering ways to simplify or even remove this heuristic, e.g. in https://github.com/pytorch/vision/pull/7340
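As torchvision's v2 documentation describes the heuristic: if the sample contains an explicit image or video, every pure tensor is passed through; otherwise only the first pure tensor (in flattened input order) is treated as an image and the rest are passed through. A plain-Python sketch of that rule, where `MyImage` is a hypothetical stand-in for the image and video datapoint classes (torchvision also counts PIL images) and `needs_transform` is an illustrative name:

```python
import torch


class MyImage(torch.Tensor):
    """Hypothetical stand-in for an image or video datapoint class."""


def needs_transform(flat_inputs):
    """Return one bool per input: True if a transform would touch it."""
    has_explicit_image = any(isinstance(x, MyImage) for x in flat_inputs)
    seen_pure_tensor = False
    flags = []
    for x in flat_inputs:
        if type(x) is torch.Tensor:  # a pure tensor, not a datapoint subclass
            flags.append(not has_explicit_image and not seen_pure_tensor)
            seen_pure_tensor = True
        else:
            # Datapoints are transformed; anything else (ints, strings, ...) passes through.
            flags.append(isinstance(x, torch.Tensor))
    return flags
```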

NicolasHug commented 1 year ago

Here's an update on the items above, as of v0.16:

Subclass (un)wrapping

The default behaviour hasn't changed: by default, operations on TVTensors (renamed from Datapoint) "unwrap" to a pure tensor. But we have added a set_return_type() context manager / global config flag for those who want to get TVTensors back instead. Read more here

Bounding box clamping

No changes: all geometric bounding box operations still clamp.

Enforce a single BoundingBox instance in all transforms?

All transforms assume that a single BoundingBoxes instance is present, but it's not really enforced either. Results are undefined if there's more than one. Also, BoundingBoxes objects don't accept arbitrary leading batch dimensions anymore, i.e. their shape is restricted to (num_boxes, 4) (for XYXY format).

How to handle labels?

We don't. Label classes are still in the prototype area and there is no plan to make them stable for now. For the transforms that need to handle labels, like SanitizeBoundingBoxes or CutMix / MixUp, we added a labels_getter parameter with a default heuristic that should handle most use-cases properly (and if not, users can set it to suit their needs).

How to smoothly support "pairwise" transforms?

We have released CutMix and MixUp in the stable area, with a labels_getter parameter (see above). They're meant to be used on batches, i.e. after the DataLoader. Read more here

Supporting user-defined datapoints and datapoints methods

We have removed the public methods on TVTensors and provided a public interface for users to create their own TVTensors and register custom kernels to transform those subclasses. Read more here

Tensor pass-through heuristic

It hasn't changed and it's vaguely documented here

(but again: it's very advanced usage, most users don't need to care about this)

sklum commented 1 year ago

I'm not sure if this is the correct place to post feedback, but I think the clamp-by-default behavior for geometric transforms is unnecessary. Coming to the v2 transforms without looking through the tickets, I didn't realize that clamping was happening just because I was using the BoundingBox class; the behavior is pretty opaque. Beyond this, for my use case, I actually don't want to clamp. It's unclear to me whether there is a way to use this class and the associated functionality of transforms like RandomIoUCrop without clamping. To me, a clearer delineation of functionality would be to make clamping a transform of its own, or perhaps an argument to SanitizeBoundingBoxes.

NicolasHug commented 1 year ago

Thanks for the feedback @sklum, we're still open to allowing a non-clamping behaviour through a parameter.

for my use case, I actually don't want to clamp

Can I ask what your use-case is and why clamping is not desired?

sklum commented 1 year ago

Sure thing! In my case it's actually interesting to know if the model predicts an object extends out of the image frame (and by how much) and because of the deformability of the class in question it's hard to estimate that from, say, aspect ratios of a bounding box at the edge of the image.

More generally, though, the clamping behavior just feels to me like hidden functionality. It may be the standard in 95+% of use cases, but I think I'd rather know that I was specifying the behavior than have it occur under the hood. To me that's what the transform abstraction represents: explicitly specifying the transformations you want to occur on your data. FWIW I think a transform is more elegant for this than clamp=False arguments to every geometric transform.

I hope this is helpful!

NicolasHug commented 9 months ago

Since the v2 transforms have been released as stable in the new 0.17 release of torchvision, I'll close this issue. @sklum thanks a lot for your feedback and for the feature request, I've opened https://github.com/pytorch/vision/issues/8254 to keep track of that.