pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[RFC] Abstractions for segmentation / detection transforms #1406

Open fmassa opened 5 years ago

fmassa commented 5 years ago

This is a proposal. I'm not sure yet it's the best way to achieve this, so I'm putting this up for discussion.

tl;dr

Have specific tensor subclasses for BoxList / SegmentationMask, CombinedObjects, etc., which inherit from torch.Tensor and override methods to properly dispatch to the relevant implementations. Depends on __torch_function__ from https://github.com/pytorch/pytorch/issues/22402 (implemented in https://github.com/pytorch/pytorch/pull/27064).

Background

For more than 2 years now, users have asked for ways of performing transformations of multiple inputs at the same time, for example for semantic segmentation or object detection https://github.com/pytorch/vision/issues/9

The recommended solution in this case is to use the functional transforms (https://github.com/pytorch/vision/issues/230, https://github.com/pytorch/vision/issues/1169), but for simple cases this is a bit verbose.
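For reference, here is a minimal sketch of that functional pattern for an image/mask pair: sample the random parameters once, then apply them identically to every input. To keep it runnable with torch alone, it uses plain tensor slicing and `torch.flip` in place of the `torchvision.transforms.functional` calls; the function name `joint_random_crop_flip` is hypothetical.

```python
import random

import torch


def joint_random_crop_flip(image, mask, size):
    """Hypothetical helper: one random crop + flip, shared by image and mask."""
    h, w = image.shape[-2:]
    th, tw = size
    # sample every random parameter exactly once...
    top = random.randint(0, h - th)
    left = random.randint(0, w - tw)
    flip = random.random() < 0.5
    # ...then apply the same deterministic operations to each input
    out = []
    for x in (image, mask):
        x = x[..., top:top + th, left:left + tw]
        if flip:
            x = torch.flip(x, dims=[-1])
        out.append(x)
    return out


image = torch.arange(16.0).view(1, 4, 4)
mask = torch.arange(16).view(1, 4, 4)
image_out, mask_out = joint_random_crop_flip(image, mask, (2, 2))
```

Every new pipeline needs a hand-written function like this one, which is the verbosity referred to above.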

Requirements

Ideally, we would want the following to be possible:

  1. work with a Compose style interface for simple cases
  2. support more than a single input of each type (for example, two images and one segmentation mask)
  3. support joint rotations / rescale with different hyperparameters for different input types (images can do bilinear interpolation, segmentation maps should do nearest interpolation)
  4. be simple and modular

Proposed solution

We define new classes for each type of object, which should all inherit from torch.Tensor, and implement / override a few specific methods. It might depend on __torch_function__ from https://github.com/pytorch/pytorch/issues/22402 (implemented in https://github.com/pytorch/pytorch/pull/27064)

Work with Compose-style

We propose to define a CombinedObjects class (better names welcome), which is a collection of arbitrary objects (potentially named, but that's not a requirement). Calling any of its methods should dispatch to the corresponding methods of its constituents. A basic example is below (I'll mention a few more points about it afterwards):

class CombinedObjects(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def hflip(self):
        result = {}
        for name, value in self.kwargs.items():
            result[name] = value.hflip()
        return type(self)(**result)

In this way, if the underlying objects follow the same protocol (i.e., implement the required functions), this should allow combining an arbitrary number of objects with a Compose API via

import random

# example for flip
class RandomFlip(object):
    def __call__(self, x):
        # implementation stays almost the same
        # but now, `x` can be an Image, or a CombinedObjects
        if random.random() > 0.5:
            x = x.hflip()
        return x

transforms = Compose([
    Resize(300),
    RandomFlip(),
    RandomColorAugment()
])

inputs = CombinedObjects(img1=x, img2=y, mask=z)
output = transforms(inputs)

which satisfies points 1 and 2 above, and part of point 3 (except for the different transformation hyperparameters for images / segmentation masks, which I'll cover next).
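A minimal runnable sketch of this protocol, assuming a hypothetical `Image` wrapper whose `hflip()` flips the last (width) dimension. `RandomFlip` takes a probability `p` (an addition, so the example can be made deterministic), and `get()` is a guess at the retrieval helper this proposal mentions further down:

```python
import random

import torch


class Image:
    """Hypothetical wrapper: any object exposing hflip() can take part."""

    def __init__(self, data):
        self.data = data

    def hflip(self):
        # flip along the width (last) dimension
        return Image(torch.flip(self.data, dims=[-1]))


class CombinedObjects:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def hflip(self):
        # dispatch the call to every constituent and re-wrap the results
        return type(self)(**{name: value.hflip() for name, value in self.kwargs.items()})

    def get(self, name):
        return self.kwargs[name]


class RandomFlip:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, x):
        # a single random draw per call, so every constituent gets the same decision
        if random.random() < self.p:
            x = x.hflip()
        return x


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


img1 = Image(torch.arange(6.0).view(1, 2, 3))
img2 = Image(torch.arange(6.0).view(1, 2, 3) * 10)
mask = Image(torch.tensor([[[0, 1, 2], [0, 1, 2]]]))
out = Compose([RandomFlip(p=1.0)])(CombinedObjects(img1=img1, img2=img2, mask=mask))
```

Note that `RandomFlip` never inspects what it receives: the same transform works on a single `Image` or on a `CombinedObjects` bundle.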

Different behavior for mask / boxes / images

In the same vein as the CombinedObject approach from the previous section, we would have subclasses of torch.Tensor for BoxList / SegmentationMask / etc which would override the behavior of specific functions so that they work as expected.

For example (and using code snippets from https://github.com/pytorch/pytorch/pull/25629), we can define a class for segmentation masks where rotation / interpolation / grid sample always behave with nearest interpolation:

import torch
import torch.nn.functional as F

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    "Register an implementation of a torch function for a Tensor-like object."
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class SegmentationMask(torch.Tensor):
    # __torch_function__ is a classmethod on Tensor subclasses
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        # everything we don't override behaves like a regular torch.Tensor
        return super().__torch_function__(func, types, args, kwargs)

@implements(torch.nn.functional.interpolate)
def interpolate(input, size=None, scale_factor=None, **kwargs):
    # force nearest interpolation, whatever mode the caller asked for;
    # unwrap to a plain tensor first so the call does not dispatch back here
    return F.interpolate(input.as_subclass(torch.Tensor), size=size,
                         scale_factor=scale_factor, mode='nearest')

and we can also give custom implementations for bounding boxes:

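BOX_HANDLED_FUNCTIONS = {}  # separate registry: reusing HANDLED_FUNCTIONS would clobber the mask's interpolate

class BoxList(torch.Tensor):
    # need to define height and width somewhere as an attribute
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func in BOX_HANDLED_FUNCTIONS:
            return BOX_HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))
        return super().__torch_function__(func, types, args, kwargs)

def interpolate_boxes(boxes, scale_factor=None, **kwargs):
    # boxes are not resampled: their coordinates are scaled (assumes scale_factor is given)
    return boxes * scale_factor

BOX_HANDLED_FUNCTIONS[torch.nn.functional.interpolate] = interpolate_boxes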

This would cover the rest of point 3. Because they are subclasses of torch.Tensor, they behave as Tensors except in the particular cases where we override the behavior. This would be used as follows:

boxes = BoxList(torch.rand(10, 4), ...)
masks = SegmentationMask(torch.rand(1, 100, 200))
image = torch.rand(3, 100, 200)
# basically a dict of inputs
x = CombinedObjects(image=image, boxes=boxes, masks=masks)
transforms = Compose([
    RandomResizedCrop(224),
    RandomFlip(),
    ColorAugment(),
])
out = transforms(x)
# have an API for getting the elements back
image = out.get('image')
# or something like that

Be simple and modular

This is up for discussion. The fact that we are implementing subclasses that do not behave exactly like tensors can be confusing and misleading. But it does seem to simplify a number of things, and makes it possible for users to leverage the same abstractions in torchvision for their own custom types, without having to modify anything in torchvision, which is nice.

Related discussions

Some other proposals have been discussed in https://github.com/pytorch/vision/issues/230 https://github.com/pytorch/vision/issues/1169 and many other places.

cc @Noiredd @SebastienEske @pmeier for discussion

SebastienEske commented 5 years ago

Hello, If I understand correctly, this proposal implies to

I am not sure that combining transforms and the object being transformed is a good idea. I understand the need to apply the same transform with different interpolation methods, but while the right method may be easy to predict in generic cases, this may not work in some corner cases.

For example, if we assume bilinear interpolation for an image and then someone comes up and says that he prefers cubic interpolation. How does that work? Do we add yet another parameter for that? But this is a transform parameter in a Tensor-like object...

I am not sure that you saw my last (heavily edited) comment in #1315. I think it solves the issue at hand

  1. without too much rework,
  2. while keeping backward compatibility,
  3. while keeping transforms and transformed objects separate,
    • and it meets all the above objectives 1, 2, 3 and 4 (objective 2 is met with a different formulation: "being able to apply the exact same transform to several inputs")

Here is a more detailed version of it and adapted to meet all the above requirements. It only needs to update the existing transforms by adding:

The example below uses a dictionary of parameters for the constructor to be generic but it could be implemented with the normal "list" of parameters which can then be put in an internal dict anyway. This maintains backward compatibility.

We could have a master argument when creating a transform, which tells the transform to reuse the state of its parent (master) transform.

affine_augment = tv.transforms.RandomAffine({'degrees':90, 'translate':(-50, 50), 'scale':(0.5, 2)})
affine_augment_label = tv.transforms.RandomAffine({'interpolation':'nearest'}, master=affine_augment)
image = affine_augment(image)
label = affine_augment_label(label)

And inside the transform we could have something like this:

def __init__(self, params, master=None):
    self.master = master
    self.params = {}
    self.async_params = []  # parameters that do not follow master transform
    if master is not None:
        master_params = master.get_params()
        for param_name in master_params.keys():
            self.params[param_name] = master_params[param_name]
    for param_name in params.keys():
        self.params[param_name] = params[param_name]
        self.async_params.append(param_name)

def __call__(self, input, params=None):
    if self.master is not None:
        master_params = self.master.get_params()
        for param_name in master_params.keys():
            if param_name not in self.async_params:
                self.params[param_name] = master_params[param_name]
    self.update_params()  # update self.async_params
    temp_params = dict(self.params)  # copy, so per-call overrides don't leak into self.params
    if params is not None:
        for param_name in params.keys():
            temp_params[param_name] = params[param_name]
    return apply_params(input, temp_params)

def get_params(self):
    return self.params

This also allows customized transforms in the training loop:

affine_augment = tv.transforms.RandomAffine({'degrees':90, 'translate':(-50, 50), 'scale':(0.5, 2)})
# update params and apply transform
image = affine_augment(image)
# get current params (a copy, so the transform's own state is not modified)
params = dict(affine_augment.get_params())
params['interpolation'] = 'nearest'
# do not use updated params, apply the provided ones instead
label = affine_augment(label, params)
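A runnable sketch of the master/follower idea, reduced to a single horizontal-flip parameter. The class name `RandomHFlip` and the `{"flip": ...}` parameter dict are hypothetical; `get_params()` returns a copy so that callers, as in the training-loop example, cannot silently change the transform's own state.

```python
import random

import torch


class RandomHFlip:
    """Hypothetical follower-aware flip with a single parameter, "flip"."""

    def __init__(self, p=0.5, master=None):
        self.p = p
        self.master = master
        self.params = {"flip": False}

    def get_params(self):
        # return a copy, so callers cannot change this transform's own state
        return dict(self.params)

    def __call__(self, x, params=None):
        if params is None:
            if self.master is not None:
                # follower: reuse the decision the master made on its last call
                self.params = self.master.get_params()
            else:
                # master (or standalone): draw a fresh random decision
                self.params = {"flip": random.random() < self.p}
            params = self.params
        return torch.flip(x, dims=[-1]) if params["flip"] else x


flip = RandomHFlip(p=0.5)
flip_label = RandomHFlip(master=flip)
image = torch.arange(4.0).view(1, 2, 2)
label = torch.arange(4).view(1, 2, 2)
image_out = flip(image)
label_out = flip_label(label)  # flipped if and only if image_out was
```

The follower must be called after its master on each sample, since it copies whatever the master drew last.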
SebastienEske commented 5 years ago

PS: if you absolutely want to have one transform taking several inputs, we could have a BundleTransform with the above proposed transforms (or any other transform for that matter)

class BundleTransform(object):
    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, *inputs):
        if len(inputs) != len(self.transforms):
            raise ValueError('expected one input per transform')
        return [t(x) for t, x in zip(self.transforms, inputs)]

affine_augment = tv.transforms.RandomAffine({'degrees':90, 'translate':(-50, 50), 'scale':(0.5, 2)})
affine_augment_label = tv.transforms.RandomAffine({'interpolation':'nearest'}, master=affine_augment)

affine_augment_bundle = BundleTransform(affine_augment, affine_augment_label)
image, label = affine_augment_bundle(image, label)
Noiredd commented 5 years ago

Previously the transform objects were responsible for actually performing the transformation. You propose to move this logic to the objects that hold data. In your pipeline, Compose executes a Transform on some object representing data; Transform, instead of directly executing a function from tv.transforms.functional, now calls a member function of the received object. If that object happens to be a CombinedObjects instance, it calls a member function on each of its members. Finally, a member function of the data-holding object is executed. This function performs the actual transformation; it may use things from functional or elsewhere.

This means that the transformation logic would be attached to the data objects. In the end, I think there is no other way to solve this in a general way. After all, it's the user who knows what their data means and what it means to "transform" it. I think that makes sense; the last code snippet in your post made me smile, this is exactly what I need!

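What worries me, however, is the resulting complexity on the back end of this approach. Two questions here:

  1. How difficult would it be to implement my own transform?
  2. How difficult would it be to implement my own data type?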

Honestly, the concern stems from this whole __torch_function__ thing, which I can't say I understand exactly. Will I have to understand it in order to write my own transforms?

PS: Are BoxList and SegmentationMask classes that would be provided by torchvision, or would the user have to write them themselves?


@SebastienEske

For example, if we assume bilinear interpolation for an image and then someone comes up and says that he prefers cubic interpolation.

If the object-oriented transformation API stays the same, you can just construct your transform object with cubic interpolation. As I understand the above, nothing changes about that.

Your master-follower scheme would allow constructing two separate Compose objects that would consist of linked transforms. Synchronization is indeed stronger than in the RNG case, but it's still implicit - once the transform objects are created, the link becomes hidden in the lower layer of abstraction (and thus harder to track, either in case something goes wrong, or if we want to alter it on the run).

There seem to be two ways to go: either a small patch to existing transforms, or a total makeover as proposed here. Of the two most reasonable proposals for the former, param-based (#1315) and master-follower (comment), the only real difference is who passes the params to whom: the Compose object, or the transforms themselves. I've been pondering this for a while, trying to name what just doesn't feel right to me about it, and I've come to the conclusion that I think this operation is not good OOP. Exchange of parameters belongs to a higher level of abstraction; transforms should just do what they're explicitly told (especially since you allow this path in your latest proposal), not talk among each other without any managing entity knowing. Of course this is just my opinion.


To sum up, I don't dislike this proposal; I'm just cautious about it. The question that's important to me is: what will be simpler for an average PyTorch user: writing my own Compose and calling get_params on existing transforms, #1315-style, or implementing a transform and going through all that __torch_function__ stuff? Including the mental effort of understanding either approach.

fmassa commented 5 years ago

Thanks @SebastienEske and @Noiredd for your feedback!

I agree with @Noiredd's points about the master-worker approach proposed by @SebastienEske (and the param-based approach), in particular:

I think this operation is not good OOP. Exchange of parameters belongs to a higher level of abstraction; transforms should just do what they're explicitly told (especially since you allow this path in your latest proposal), not talk among each other without any managing entity knowing.

To answer @Noiredd questions

How difficult would it be to implement my own transform?

Great question, and I see two different cases here.

If you live out-of-tree (i.e., not modifying torchvision), I would imagine that users would implement a new transform by handling the different input types themselves:

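class MyNewTransform(object):
    def __call__(self, x):
        if isinstance(x, CombinedObjects):
            # apply the transform to every constituent and re-wrap the results
            return type(x)(**{name: self(value) for name, value in x.kwargs.items()})
        elif isinstance(x, BoxList):
            # do something here
            return x
        elif isinstance(x, SegmentationMask):
            # something else
            return x
        else:
            # finally, everything else (e.g. plain image tensors)
            return x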

This doesn't require monkey-patching the torchvision objects to support your new transform, and the user controls what they want to support in a very explicit way.

If your code lives in torchvision, you could extend methods to each class that you care about so that this new transform is transparently handled.

How difficult would it be to implement my own data type?

I think this should be fairly simple. An approach closely related to what I proposed is what I did in the structures folder of maskrcnn-benchmark, where we have a different class for each object and each one follows a common API (inspired by the PIL API, see for example resize for boxes, for masks and for keypoints). So if we follow that approach for now (to simplify things), it would involve something like

class Mesh(object):
    # PIL-style API, like the boxes / masks / keypoints structures
    def resize(self, size):
        pass
    def transpose(self, method):
        pass
    ...
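For illustration, a minimal non-Tensor structure following the same PIL-style protocol. `Keypoints`, its `(width, height)` `image_size` attribute and the local `FLIP_LEFT_RIGHT` constant (mirroring PIL's `Image.FLIP_LEFT_RIGHT`) are assumptions for this sketch, and coordinates are treated as continuous:

```python
import torch

FLIP_LEFT_RIGHT = 0  # hypothetical local constant, mirroring PIL's Image.FLIP_LEFT_RIGHT


class Keypoints:
    """Hypothetical structure: (N, 2) continuous xy coordinates plus the image size."""

    def __init__(self, xy, image_size):
        self.xy = xy
        self.image_size = image_size  # (width, height), as in PIL

    def resize(self, size):
        (w, h), (new_w, new_h) = self.image_size, size
        scale = torch.tensor([new_w / w, new_h / h])
        return Keypoints(self.xy * scale, size)

    def transpose(self, method):
        if method != FLIP_LEFT_RIGHT:
            raise NotImplementedError("only FLIP_LEFT_RIGHT is sketched here")
        w, _ = self.image_size
        flipped = self.xy.clone()
        flipped[:, 0] = w - flipped[:, 0]  # mirror x around the image width
        return Keypoints(flipped, self.image_size)


kp = Keypoints(torch.tensor([[1.0, 1.0]]), image_size=(4, 2))
resized = kp.resize((8, 4))
mirrored = kp.transpose(FLIP_LEFT_RIGHT)
```

Because each method returns a new object that carries its own image size, structures like this can be chained without unwrapping, though unlike a tensor subclass they do not support arbitrary torch operations.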

Now, why don't we just do this and avoid messing with __torch_function__ (which doesn't exist yet, by the way)? Well, in many cases it is useful to have everything be a torch.Tensor: you can perform many operations on them without having to unwrap / rewrap everything all the time. The __torch_function__ approach would simplify dispatching to specific functions (like interpolate or grid_sample) in some cases, but I think we can try to keep things simpler for now (especially since __torch_function__ isn't available yet and I'd need to play a bit with it).

Are BoxList and SegmentationMask classes that would be provided by torchvision, or would the user have to write them themselves?

torchvision would provide those classes

pmeier commented 5 years ago

Sorry for the late reply, but at this point I don't have much to add. Also, I don't think I understand the bigger picture well enough to give reasonable suggestions. However, if you need help implementing anything (even for experimenting), feel free to CC me again.

SebastienEske commented 5 years ago

Hello,

I am OK with whatever approach you choose, but considering the amount of change, it's probably good to be aware of the limitations of the one you are proposing. Allow me to play devil's advocate one more time ;-):

That is a lot more complicated (and bug prone) than just passing a parameter to an existing function.

If your code lives in torchvision, you could extend methods to each class that you care about so that this new transform is transparently handled.

So if someone wants to contribute a new transform, he/she will need to implement it for all the classes? Even with a single implementation in nn.functional plus reuse in other classes, I am not sure that this is very community friendly.

I also realized that part of the transform is implemented in the content class (for instance, hflip in SegmentationMask) and another part in the transform class RandomFlip. That can be a design choice, but then we need a clear guideline on what should be implemented in the content class and what should be implemented in the transform class.

This means that the transformation logic would be attached to the data objects. In the end, I think, there is no other way to solve this in a general way. After all, it's the user who knows what their data means and what does it mean to "transform it".

Users already do it with transforms and they are not attached to the data objects. Attaching the transforms to the data objects gives less flexibility and will inevitably be less generic.

@Noiredd mentioning bad/good OOP got me thinking about it ;-) While I understand that having a supervisor for the synchronization could be nice, I don't understand what would make not having one bad? Do you have any example?

To me, having transforms manage their dependencies on their own is a lightweight approach. The only reason I can think of for having a supervisor would be distributed training. But even then, all the inputs would probably be loaded on the same GPU, so it is not needed. If it were needed, I would expect it to be because of memory constraints, in which case the CombinedObjects class would be even less flexible, as its contents simply would not fit in memory.

I am far from an expert on this, but if I understand correctly, putting the transforms in the data container could violate the single responsibility principle: the new classes would be responsible for both storing/representing the data and transforming it...

The need to modify every class to introduce a new transform feels like breaking the open-closed principle (the closed part of it) as well.

PS: I'll try to think of something nicer for the master/slave approach by tomorrow.

SebastienEske commented 5 years ago

Well... I may have a better solution than the master/slave. Only I am surprised that it was not proposed before. Maybe it was and there is something wrong with it. I guess you'll tell me. Why not allow a transform to take lists of parameters and inputs?

# nbtrans is the number of transforms/inputs that the user wants to define
def __init__(self, param1, param2, ..., nbtrans=None):
    # get value of nbtrans
    self.nbtrans = nbtrans
    # check the other inputs and infer nbtrans value from them if needed
    if isinstance(param1, list):
        if self.nbtrans is None:
            self.nbtrans = len(param1)
        elif self.nbtrans != len(param1):
            raise ValueError('inconsistent number of parameters')
    if isinstance(param2, list):
        if self.nbtrans is None:
            self.nbtrans = len(param2)
        elif self.nbtrans != len(param2):
            raise ValueError('inconsistent number of parameters')
    ...
    if self.nbtrans is None:
        self.nbtrans = 1

    # create the proper set of parameters (one entry per input)
    if isinstance(param1, list):
        self.param1 = param1
    else:
        self.param1 = [param1 for _ in range(self.nbtrans)]
    if isinstance(param2, list):
        self.param2 = param2
    else:
        self.param2 = [param2 for _ in range(self.nbtrans)]
    ...

def __call__(self, input):
    if isinstance(input, list):
        if len(input) == self.nbtrans:
            # update internal state if needed
            ...
            # compute the transformation for each input/set of parameters
            return [self.apply_trans(input[i], i) for i in range(len(input))]
        elif self.nbtrans == 1:
            # update internal state if needed
            ...
            # apply the single set of parameters to every input
            return [self.apply_trans(input[i], 0) for i in range(len(input))]
        else:
            raise ValueError('wrong number of inputs')
    else:
        # should we throw an exception if self.nbtrans > 1?
        # update internal state if needed
        ...
        return self.apply_trans(input, 0)

If a parameter was originally a list we can simply test if it has become a list of lists.

And we don't need to rewrite transforms.Compose. The initial example becomes:

boxes = torch.rand(10, 4)
masks = torch.rand(1, 100, 200)
image = torch.rand(3, 100, 200)
transforms = Compose([
    RandomResizedCrop(224, interpolation=['coords', 'nearest', 'bilinear']),
    RandomFlip(input_type=['coords', 'image', 'image']),
    ColorAugment(input_type=['coords', 'image', 'image']),
])
boxes, masks, image = transforms([boxes, masks, image])

Handling the coordinates is a bit tricky, but the input_type parameter and the coords interpolation type solve this problem without using custom classes.

And we don't touch the tensors; the modifications remain in the transforms which I think is really a key point to have. I believe this satisfies all seven criteria above plus the one from @Noiredd about several transforms communicating together without supervision. Now, there is only one transform.

The possible values for input_type should probably be defined somewhere in order to have a unified interface.
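A runnable sketch of such a list-aware transform, under these assumptions: a hypothetical `RandomFlipList` whose `input_type` is either one string or one entry per input, a single random draw shared by every input, and boxes given as `(x1, y1, x2, y2)` whose image width is taken from the first non-`coords` input.

```python
import random

import torch


class RandomFlipList:
    """Hypothetical list-aware flip: one random draw applied to every input."""

    def __init__(self, p=0.5, input_type="image"):
        self.p = p
        # a list gives one input type per input; a string is shared by all inputs
        self.input_type = input_type if isinstance(input_type, list) else [input_type]
        self.nbtrans = len(self.input_type)

    @staticmethod
    def _apply(x, kind, flip, width):
        if not flip:
            return x
        if kind == "coords":
            if width is None:
                raise ValueError("flipping coordinates needs an image to get the width from")
            # mirror x and swap x1/x2 so that x1 <= x2 still holds
            return torch.stack([width - x[:, 2], x[:, 1], width - x[:, 0], x[:, 3]], dim=1)
        return torch.flip(x, dims=[-1])

    def __call__(self, inputs):
        single = not isinstance(inputs, list)
        inputs = [inputs] if single else inputs
        kinds = self.input_type if self.nbtrans > 1 else self.input_type * len(inputs)
        if len(kinds) != len(inputs):
            raise ValueError("wrong number of inputs")
        flip = random.random() < self.p  # drawn once, shared by every input
        width = next((x.shape[-1] for x, k in zip(inputs, kinds) if k != "coords"), None)
        out = [self._apply(x, k, flip, width) for x, k in zip(inputs, kinds)]
        return out[0] if single else out


boxes = torch.tensor([[1.0, 0.0, 2.0, 1.0]])
mask = torch.tensor([[[0, 1, 2, 3]]])
image = torch.arange(4.0).view(1, 1, 4)
flip = RandomFlipList(p=1.0, input_type=["coords", "image", "image"])
boxes_out, mask_out, image_out = flip([boxes, mask, image])
```

A single transform object holds the shared random state, so no second transform needs to be kept in sync.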

What do you think?

SebastienEske commented 5 years ago

Doing some extra thinking: there are at least two ways to resize coordinates, either returning the exact float values or returning rounded integer values. The first is akin to bilinear/bicubic interpolation and the second to nearest neighbor.

So maybe we also want to have the input_type parameter for the RandomResizeCrop...
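A small worked example of the two options; `resize_coords` and its `mode` values are hypothetical names:

```python
import torch


def resize_coords(xy, scale, mode="exact"):
    """Hypothetical helper: scale (N, 2) xy coordinates by `scale`."""
    scaled = xy * scale
    if mode == "nearest":
        # snap to the integer pixel grid, akin to nearest-neighbour interpolation
        # (note: torch.round rounds halves to the nearest even integer)
        return scaled.round()
    # keep sub-pixel precision, akin to bilinear / bicubic interpolation
    return scaled


xy = torch.tensor([[1.0, 3.0]])
exact = resize_coords(xy, 1.3)                    # about [[1.3, 3.9]]
snapped = resize_coords(xy, 1.3, mode="nearest")  # [[1., 4.]]
```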

SebastienEske commented 4 years ago

Hello, it's been a while, and since this has not moved forward, I feel like maybe I voiced my concerns a bit too much. If that's the case, I'm sorry about it.

In any case, my point was to make sure that all concerns and alternatives are presented. The above is all I got so @fmassa whatever you decide to go with will be good for me.

fmassa commented 4 years ago

Hi @SebastienEske

No worries at all, thanks for sharing your thoughts!

I've been slow to respond on this one because I've been busy working on other parts of torchvision. I'm going to be getting back to segmentation / detection transforms most probably in January

vadimkantorov commented 4 years ago

One issue I stumbled on long ago when reading PASCAL multiclass segmentation masks: https://github.com/pytorch/pytorch/issues/5436

ErezYosef commented 3 years ago

Hi everyone. I want to suggest a solution to this problem, similar to the one by @SebastienEske in https://github.com/pytorch/vision/issues/1406#issuecomment-537775978.

I needed this feature a while ago, so I thought I would share my solution.

Solution:

  1. Super simple to use (and implement!).
  2. Works with a Compose-style interface.
  3. Supports almost any type of transform and data type.
  4. Fully backward compatible.

Transform structure:

As an example, I chose RandomHorizontalFlip. See the changes I added and the resulting functionality:

import torch
import torchvision.transforms.functional as F

class RandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
        self.track = None                  # added to save the current state

    def forward(self, img, enforce=None):     # added enforce input
        threshold = enforce if enforce is not None else torch.rand(1)     # added usage of last state
        if threshold < self.p:
            self.track = 0                  # save the current state (0 = flipped)
            return F.hflip(img)
        else:
            self.track = 1                 # save the current state (1 = not flipped)
            return img

Functionality:

example code:

flip_t = RandomHorizontalFlip(0.5)
img = flip_t(img)

Usage:

Uncorrelated transforms:

img1 = flip_t(img1)
img2 = flip_t(img2)

Same transform on 2 images:

img1 = flip_t(img1)
img2 = flip_t(img2, enforce = flip_t.track)

Same transform on N images:

img_list = [img1, img2, ..., imgN]
img_list[0] = flip_t(img_list[0])
for i in range(1, len(img_list)):
    img_list[i] = flip_t(img_list[i], enforce=flip_t.track)

Different hyperparameters:

Assume we want to apply RandomRotation on img and mask with different interpolations.

img_t = RandomRotation(degrees, resample=PIL.Image.BILINEAR)
mask_t = RandomRotation(degrees, resample=PIL.Image.NEAREST)
img = img_t(img)
mask = mask_t(mask, enforce=img_t.track)


Custom transformation:

Assume img is labeled "right" or "left" by a string label (for example, there is a right/left arrow in the image). If we flip the image horizontally, we should change its label too:

def label_custom_flip(label, enforce=None):
    if enforce == 1:  # track == 1 means the image was not flipped
        return label
    new_label = 'right' if label == 'left' else 'left'  # flip the str label
    return new_label

img1 = flip_t(img1) # same as before
label = label_custom_flip(label, enforce = flip_t.track)

Compose:

We can also adapt the compose interface in the same way. In this case:

my_transforms = Compose([...])
img1 = my_transforms(img1)
img2 = my_transforms(img2, enforce = my_transforms.tracks)

Where

---

I hope this is a good solution and that you like my suggestion. I would like to hear your opinion, @SebastienEske @fmassa. Erez.

ErezYosef commented 3 years ago

Continuing my comment above (https://github.com/pytorch/vision/issues/1406#issuecomment-777880616), here is my implementation of Compose:

class Compose:
    def __init__(self, transforms):
        self.transforms = transforms
        self.tracks = [None] * len(self.transforms)           # added

    def __call__(self, img, enforce=None):
        for i, t in enumerate(self.transforms):
            if enforce is None or enforce[i] is None:          # added: fresh random state
                img = t(img)
            else:                                               # added: replay a recorded state
                img = t(img, enforce[i])
            # added: record the state on every call, so that `tracks` is filled
            # even when no `enforce` is given (otherwise it would stay all None)
            self.tracks[i] = t.track if hasattr(t, 'track') else None
        return img

Please review this implementation also.
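A self-contained, torch-only sketch of the record/replay idea through a whole pipeline. It swaps `F.hflip` for `torch.flip` with a configurable `dim` (the class name `RandomFlip` and its `dim` argument are assumptions), and records `tracks` on every call so that the first input's random decisions can be replayed on the second:

```python
import torch


class RandomFlip(torch.nn.Module):
    """Hypothetical record/replay flip: track == 0 -> flipped, 1 -> not flipped."""

    def __init__(self, p=0.5, dim=-1):
        super().__init__()
        self.p = p
        self.dim = dim
        self.track = None

    def forward(self, img, enforce=None):
        threshold = enforce if enforce is not None else torch.rand(1).item()
        if threshold < self.p:
            self.track = 0
            return torch.flip(img, dims=[self.dim])
        self.track = 1
        return img


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms
        self.tracks = [None] * len(transforms)

    def __call__(self, img, enforce=None):
        enforce = list(enforce) if enforce is not None else None  # don't alias self.tracks
        for i, t in enumerate(self.transforms):
            if enforce is not None and enforce[i] is not None:
                img = t(img, enforce[i])  # replay the recorded decision
            else:
                img = t(img)              # draw a fresh random decision
            # record state on every call, not only when replaying
            self.tracks[i] = getattr(t, "track", None)
        return img


pipeline = Compose([RandomFlip(0.5, dim=-1), RandomFlip(0.5, dim=-2)])
image = torch.arange(9.0).view(1, 3, 3)
mask = torch.arange(9).view(1, 3, 3)
image_out = pipeline(image)
mask_out = pipeline(mask, enforce=pipeline.tracks)  # same flips as image_out
```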

Conclusion:

Basically, I aim to change the naive transforms implementation to a more powerful and robust form:

vmoens commented 3 years ago

With respect to:

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    "Register an implementation of a torch function for a Tensor-like object."
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class SegmentationMask(torch.Tensor):
    def __torch_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Note: this allows subclasses that don't override
        # __torch_function__ to handle DiagonalTensor objects.
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

@implements(torch.nn.functional.interpolate)
def interpolate(...):
    # force nearest interpolation
    return torch.nn.functional.interpolate(..., mode='nearest')

and to elaborate a bit more on what this 'function overloading' using __torch_function__ would look like, here is a more complete snippet:

import torch

HANDLED_FUNCTIONS_IMAGE = {}
HANDLED_FUNCTIONS_SEG = {}

def implements_segmentation(torch_function):
    "Register an implementation of a torch function for a Tensor-like object."
    def decorator(func):
        HANDLED_FUNCTIONS_SEG[torch_function] = func
        return func
    return decorator

def implements_image(torch_function):
    "Register an implementation of a torch function for a Tensor-like object."
    def decorator(func):
        HANDLED_FUNCTIONS_IMAGE[torch_function] = func
        return func
    return decorator

class SegmentationMask():
    def __init__(self, data):
        self.data = data

    def __torch_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS_SEG:
            return NotImplemented
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS_SEG[func](self.data, *args[1:], **kwargs)

class Image():
    def __init__(self, data):
        self.data = data

    def __torch_function__(self, func, types, args, kwargs):
        if func not in HANDLED_FUNCTIONS_IMAGE:
            return NotImplemented
        if not all(issubclass(t, self.__class__) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS_IMAGE[func](self.data, *args[1:], **kwargs)

@implements_image(torch.nn.functional.interpolate)
def interpolate_image(*args, **kwargs):
    if 'mode' in kwargs:
        del kwargs['mode']
    return torch.nn.functional.interpolate(*args, **kwargs, mode='linear')

@implements_segmentation(torch.nn.functional.interpolate)
def interpolate_segmentation(*args, **kwargs):
    if 'mode' in kwargs:
        del kwargs['mode']
    return torch.nn.functional.interpolate(*args, **kwargs, mode='nearest')

s = SegmentationMask(
    torch.tensor([1,2,3,4,5,6,7,8,9]).float().view(1,3,3)
)
print('segmentation: ', torch.nn.functional.interpolate(s, 2))

s = Image(
    torch.tensor([1,2,3,4,5,6,7,8,9]).float().view(1,3,3)
)
print('image: ', torch.nn.functional.interpolate(s, 2))

with results

segmentation:  tensor([[[1., 2.],
         [4., 5.],
         [7., 8.]]])
image:  tensor([[[1.2500, 2.7500],
         [4.2500, 5.7500],
         [7.2500, 8.7500]]])
datumbox commented 3 years ago

There is an alternative approach covered here: https://github.com/pmeier/torchvision-datasets-rework/pull/1