pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Reshape/View as a module? #720

Closed pkdogcom closed 5 years ago

pkdogcom commented 5 years ago

I was wondering if there is a module that performs reshape/view so that it can be added to nn.Sequential just like other modules such as Conv2d or Linear. The reason I want this feature rather than simply calling torch.reshape or tensor.view is that it lets me make the reshape/view a configurable plugin (especially when combined with global pooling, which can be switched on and off) in my model.

varunagrawal commented 5 years ago

You can easily define your own module:

class View(nn.Module):
    def __init__(self, shape):
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

Just plug that into sequential now.

pkdogcom commented 5 years ago

Yes, that's exactly what I'm currently doing. Thanks for confirming that.

fmassa commented 5 years ago

I believe we won't be adding support in nn for View, see https://github.com/pytorch/pytorch/issues/2486. But it's fairly straightforward to create your own module for it.

ruppphil commented 5 years ago

You can easily define your own module:

class View(nn.Module):
    def __init__(self, shape):
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

Just plug that into sequential now.

I implemented this, and it works fine when I add it to my nn.Sequential. Unfortunately, when it comes to starting the training, it gives me the following error: 'View' object has no attribute '_modules'. It seems like something goes wrong with the inheritance?

TheCodez commented 5 years ago

@ruppphil try calling the super constructor:

class View(nn.Module):
    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)
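
For example, once the super call is in place, the module drops into an nn.Sequential as intended (a minimal sketch; the layer sizes are illustrative):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    View((-1, 16)),          # collapse (N, 16, 1, 1) into (N, 16)
    nn.Linear(16, 10),
)
print(model(torch.rand(4, 3, 28, 28)).shape)  # torch.Size([4, 10])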

fmassa commented 5 years ago

Exactly what @TheCodez mentioned!

dashesy commented 5 years ago

@fmassa

I believe we won't be adding support in nn for View, see pytorch/pytorch#2486 But it's fairly straightforward to create your own module for it.

It is fairly straightforward to add the simple module (many people do). I have this one and plug it in before the FC layer:

class FCView(nn.Module):
    def __init__(self):
        super(FCView, self).__init__()

    # noinspection PyMethodMayBeStatic
    def forward(self, x):
        n_b = x.data.size(0)
        x = x.view(n_b, -1)
        return x

    def __repr__(self):
        return 'view(nB, -1)'
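
For context, a hypothetical usage sketch of plugging it in front of the classifier (channel counts here are illustrative, not from the original comment):

import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    FCView(),                 # collapse (N, 8, 1, 1) into (N, 8)
    nn.Linear(8, 10),
)
print(net(torch.rand(2, 3, 32, 32)).shape)  # torch.Size([2, 10])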

But unless it is part of official PyTorch, many people will not use it, as was also requested here.

Looking at torchvision.models.resnet34, this is its forward:

class ResNet(nn.Module):

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x

This means resnet34 could have been

class ResNet(nn.Sequential):

There would be no need to implement forward at all, if it were not for the reshape. Then manipulating the model would be more straightforward and we would not need to treat it differently. resnet34 is just an example, but in general it would be nice to also have a simple reshape nn.Module and use it instead of re-implementing forward.

fmassa commented 5 years ago

We have since added an nn.Flatten module, which does the job of nn.Reshape for the particular case of converting from a convolutional to a fully connected layer.
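
For reference, a minimal sketch of that conv-to-fc transition using nn.Flatten (layer sizes are illustrative):

import torch
import torch.nn as nn

head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),   # (N, 512, H, W) -> (N, 512, 1, 1)
    nn.Flatten(),              # flattens everything after the batch dim: (N, 512, 1, 1) -> (N, 512)
    nn.Linear(512, 1000),
)
print(head(torch.rand(2, 512, 7, 7)).shape)  # torch.Size([2, 1000])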

No need to implement forward at all. If it was not for the reshape. Then manipulating it would have been more straightforward and we would not need to treat it differently. resnet34 is just an example, but in general it would be nice to also have a simple reshape nn.module and use it instead of re-implemeting forward.

There are cases where having the explicit forward written out is better. For example, what if you want to return intermediate features as well?
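
For instance, a hypothetical model (not an existing torchvision one) where the explicit forward makes it trivial to expose intermediates:

import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(feat1)
        logits = self.fc(feat2.mean(dim=(2, 3)))  # global average pool, then classify
        return logits, feat1, feat2               # intermediate features come along for free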

dashesy commented 5 years ago

@fmassa nn.Flatten would solve most issues, so I should open an issue for torchvision to start using it, so that we could easily manipulate the models. For intermediate features I have a Tee module that is similar to nn.Sequential, but instead of forwarding x through each internal module consecutively, it returns the outputs of all of them. The module that consumes Tee should know the number of outputs, or accept *args-style input and work with those tensors.

class TeeHeads(nn.Module):
    def __init__(self, *nets):
        """Create multi-head network (multiple outputs)
        :param nets: modules to form a Tee
        :type nets: nn.Module
        """
        super(TeeHeads, self).__init__()
        for idx, net in enumerate(nets):
            self.add_module("{}".format(idx), net)

    def forward(self, *inputs):
        outputs = []
        for module in self._modules.values():
            outputs.append(module(*inputs))
        return outputs
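
A hypothetical usage sketch (the head modules here are illustrative):

import torch
import torch.nn as nn

# Both heads consume the same feature map; their outputs come back together.
heads = TeeHeads(nn.AdaptiveAvgPool2d(1), nn.AdaptiveMaxPool2d(1))
feat = torch.rand(2, 64, 8, 8)
avg_out, max_out = heads(feat)
print(avg_out.shape, max_out.shape)  # torch.Size([2, 64, 1, 1]) for both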

fmassa commented 5 years ago

@dashesy I'm not sure we want to modify the implementations to all leverage nn.Flatten and make everything a single nn.Sequential.

From my point of view, as soon as we want to do something slightly more custom, having the model be an nn.Sequential just becomes an annoyance, and we start by removing the nn.Sequential so that we have full control over what the model executes.

Having an nn.Sequential model has its conveniences, e.g., it makes it easier to break the model in the middle. But it also moves away from the programming style that we have encouraged since the beginning in PyTorch, which is that the model is code and is entirely defined by your forward pass, giving you full freedom to modify the model as you wish.
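
To illustrate the "break the model in the middle" convenience, a sketch (layer choices are illustrative):

import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(32 * 32 * 32, 10),
)
trunk = model[:4]  # slicing an nn.Sequential yields another nn.Sequential: an instant feature extractor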

If all the reference implementations leverage nn.Sequential, it might give the impression to new users that that's the only way they can implement models.

I'm adding more people to the discussion, as I think this deserves some more discussion. @soumith @cpuhrsch @zhangguanheng66 @vincentqb thoughts?

makeyourownneuralnetwork commented 4 years ago

Did this nn.View() ever get implemented upstream?

makeyourownneuralnetwork commented 4 years ago

You can easily define your own module:

class View(nn.Module):
    def __init__(self, shape):
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

Just plug that into sequential now.

@varunagrawal I made a slight variation by adding a comma so it can cope with both tuples and simple integers as shapes. Seems to work in my limited testing - feedback welcome.

class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape,  # extra comma

    def forward(self, x):
        return x.view(*self.shape)

ncuxomun commented 4 years ago

You can easily define your own module:

class View(nn.Module):
    def __init__(self, shape):
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

Just plug that into sequential now.

@varunagrawal I made a slight variation by adding a comma so it can cope with both tuples and simple integers as shapes. Seems to work in my limited testing - feedback welcome.

class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape,  # extra comma

    def forward(self, x):
        return x.view(*self.shape)

I implemented View() as shown above, but instead of the data being reshaped as requested, I see that View() itself was returned. Am I doing something different?

dlmgary commented 3 years ago

@ncuxomun I think you forgot to add the super().__init__() in the __init__() method.

ndgnuh commented 1 year ago

I know this issue is closed, but I'd still like to share my hack:

class Reshape(nn.Module):
    def __init__(self, expr: str, mode="reshape"):
        super().__init__()
        assert mode in ["reshape", "view"]
        self.mode = mode
        self.arg_expr, self.ret_expr = expr.replace(" ", "").replace(",", ", ").split("->")
        exec(f"self.get_new_shape = lambda {self.arg_expr}: ({self.ret_expr})")

    def extra_repr(self):
        args = [
            f"[{self.arg_expr}] -> [{self.ret_expr}]".upper(),
            f"mode={self.mode}",
        ]
        return ", ".join(args)

    def forward(self, x):
        shape = self.get_new_shape(*x.shape)
        return getattr(x, self.mode)(*shape)

class Permute(nn.Module):
    def __init__(self, expr: str):
        super().__init__()

        expr = expr.replace(" ", "").replace(",", ", ")
        src_expr, dst_expr = expr.split("->")
        src = src_expr.split(", ")
        dst = dst_expr.split(", ")
        assert len(src) == len(dst)
        self.order = tuple([src.index(d) for d in dst])
        self.src_expr = src_expr
        self.dst_expr = dst_expr

    def extra_repr(self):
        return f"[{self.src_expr}] -> [{self.dst_expr}]".upper()

    def forward(self, x):
        return x.permute(self.order)

See it in action:

image = torch.rand(10, 3, 32, 128)
to_seq = nn.Sequential(
    Reshape("n,c,h,w -> n, c * h, w", mode="view"),
    Permute("n, c * h, w -> n, w, c * h")
)
print(to_seq) # pretty print
seq = to_seq(image)
print(seq.shape)

Output:

Sequential(
  (0): Reshape([N, C, H, W] -> [N, C*H, W], mode=view)
  (1): Permute([N, C*H, W] -> [N, W, C*H])
)
torch.Size([10, 128, 96])

Also, these layers seem exportable to ONNX files:

# works fine
torch.onnx.export(to_seq, image, '/tmp/test.onnx',
                  do_constant_folding=True,
                  opset_version=12,
                  input_names=["images"],
                  output_names=["seq"])