silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Multi-input / multi-output reversible layers #43

Closed cetmann closed 4 years ago

cetmann commented 4 years ago

Description

Sorry for the many questions over the last few days, I'd be happy to contribute.

I have a class of networks in which the channels are split at some point, after which several (reversible) layers are applied to each chunk of the channels. Then, these are recombined by concatenating the channels again. Channel splitting and concatenating (with fixed chunk size) are invertible functions and should work really well with this.

A possible implementation is this one:

import torch


class SplitChannels(torch.nn.Module):
    def __init__(self, split_location):
        super(SplitChannels, self).__init__()
        # Channel index at which the input is split into two chunks.
        self.split_location = split_location

    def forward(self, x):
        return x[:, :self.split_location], x[:, self.split_location:]

    def inverse(self, x, y):
        # torch.cat expects the tensors as a single sequence.
        return torch.cat((x, y), dim=1)


class ConcatenateChannels(torch.nn.Module):
    def __init__(self, split_location):
        super(ConcatenateChannels, self).__init__()
        # Channel index used to undo the concatenation in inverse().
        self.split_location = split_location

    def forward(self, x, y):
        return torch.cat((x, y), dim=1)

    def inverse(self, x):
        return x[:, :self.split_location], x[:, self.split_location:]
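
As a sanity check of the invertibility claim, here is a minimal sketch that splits a tensor by channels, concatenates the chunks again, and verifies the round trip. The function names split_channels and concatenate_channels are illustrative only (they are not part of MemCNN), and note that torch.cat takes the tensors as a single tuple or list.

```python
import torch


def split_channels(x, split_location):
    """Split an (N, C, H, W) tensor into two chunks along the channel axis."""
    return x[:, :split_location], x[:, split_location:]


def concatenate_channels(a, b):
    """Inverse of split_channels; torch.cat receives the tensors as one tuple."""
    return torch.cat((a, b), dim=1)


x = torch.rand(1, 3, 32, 32)
a, b = split_channels(x, 2)            # 2 channels and 1 channel
restored = concatenate_channels(a, b)  # same shape as x again
round_trip_ok = torch.equal(restored, x)
```

With a fixed split location, neither operation loses information, which is why both directions can serve as forward or inverse.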

One could also take the input as a list. The output, however, should always be a tensor or a tuple of tensors, because InvertibleCheckpointFunction is an autograd.Function object (whose forward method wraps this) and does not accept lists as output types.

In my architecture, the splitting and recombination appear in a nested fashion, so implementing them as a single reversible layer is not possible (as is done with the coupling blocks, where the splitting and recombination happen in the same layer).

With the current implementation, this is not possible. The error thrown is the following:

/usr/local/lib/python3.6/dist-packages/memcnn/models/revop.py in forward(self, xin)
    135         """
    136         if not self.disable:
--> 137             y = InvertibleCheckpointFunction.apply(xin, self._fn.forward, self._fn.inverse, self.keep_input, self.num_bwd_passes, *[p for p in self._fn.parameters() if p.requires_grad])
    138             if not self.keep_input:
    139                 if not pytorch_version_one_and_above:

/usr/local/lib/python3.6/dist-packages/memcnn/models/revop.py in forward(ctx, input_t, fn, fn_inverse, keep_input, num_bwd_passes, *weights)
     24             output = ctx.fn(x)
     25 
---> 26         detached_output = output.detach_()  # Detaches y in-place (inbetween computations can now be discarded)
     27 
     28         # store these tensor nodes for backward pass

AttributeError: 'list' object has no attribute 'detach_'

What I tried

So I already tried converting the code in revop.py such that it operates on tuples instead of tensors if needed, e.g.

        if isinstance(output, tuple):
            # Detaches y in-place (inbetween computations can now be discarded)
            detached_output = [element.detach_() for element in output] 
            detached_output = tuple(detached_output)
        else:
            detached_output = output.detach_()

A problem is that the detaching does not work in-place here (no idea why), so one would have to use detach() instead of detach_(). If I use detach(), however, PyTorch also throws an error:

x = torch.rand(1,3,32,32)
print(InvertibleModuleWrapper(SplitChannels(2))(x))
RuntimeError: setStorage: sizes [2, 32, 32], strides [1024, 32, 1], and storage offset 0 requiring a storage size of 2048 are out of bounds for storage with numel 0

silvandeleemput commented 4 years ago

@cetmann Hi, this is actually something that I would like to have in MemCNN as well, and it should not be that hard with the current setup, although it might be a bit tricky to get right.

If you want to have a look at it, check how checkpointing in PyTorch is implemented, which also allows for multi-input/multi-output operations: https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint
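
For reference, a minimal sketch of how torch.utils.checkpoint.checkpoint is called with a function that takes two inputs and returns two outputs. The function two_in_two_out is hypothetical, and passing use_reentrant=False assumes a reasonably recent PyTorch release. This only illustrates the multi-input/multi-output calling convention; it is not MemCNN's InvertibleCheckpointFunction.

```python
import torch
from torch.utils.checkpoint import checkpoint


def two_in_two_out(x, y):
    """Hypothetical function with two tensor inputs and a tuple of two outputs."""
    return x * 2.0, y + x


a = torch.rand(4, requires_grad=True)
b = torch.rand(4, requires_grad=True)

# Inputs are passed as extra positional arguments; the tuple output is returned as-is.
out1, out2 = checkpoint(two_in_two_out, a, b, use_reentrant=False)

# Gradients flow to both inputs through the checkpointed call.
(out1.sum() + out2.sum()).backward()
```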

I might have a look at this issue myself this weekend.

EDIT: If you want to implement this I can recommend using tuples instead of lists for this.

silvandeleemput commented 4 years ago

@cetmann Also, looking at what you are trying to achieve, you might want to have a look at how splitting and combining is done in the AdditiveCoupling and AffineCoupling classes. These also split inputs in half by channels and these can be easily wrapped by the InvertibleModuleWrapper.

cetmann commented 4 years ago

@silvandeleemput Sounds great! If you need any help or suggestions, I'd be more than happy to help. I'll give it a shot myself over the next few days.

By the way, detach_ could not be used instead of detach above because index slicing (which I used) only creates a view. One would have to clone the slices before detaching, but I wonder whether that still does what we want: why do the detaching in-place then?
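
To make the view issue concrete, here is a small sketch, assuming a PyTorch version in which detach_() refuses to operate on views by raising a RuntimeError (this matches the behavior described above, but the exact message may differ between releases). The variable names are illustrative.

```python
import torch

x = torch.rand(1, 3, 32, 32, requires_grad=True)
chunk = x[:, :2]  # basic slicing returns a view of x, not a copy

# In-place detaching of a view is expected to be rejected.
try:
    chunk.detach_()
    inplace_detach_failed = False
except RuntimeError:
    inplace_detach_failed = True

# Out-of-place detach works on the view and still shares x's storage.
detached = chunk.detach()

# Cloning first gives an independent tensor that can be detached in-place,
# at the cost of an extra copy of the data.
cloned = chunk.clone()
cloned.detach_()
```

The clone variant avoids the error but allocates new memory for each slice, which is the trade-off raised in the question above.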

Can you give me a pointer why the input needs to be detached in-place and everything else is done with the standard detach operation? If detaching in-place is not necessary, I'd suggest not doing that, otherwise people might run into the same problem that I ran into.

silvandeleemput commented 4 years ago

Great, good luck! Here are some pointers:

cetmann commented 4 years ago

Those were good suggestions. I have a working version now and am preparing some tests and will create a pull request once these are done.

At first I didn't want to mess with your function arguments, so I originally went for list inputs (Keras-style) instead of combining inputs and weights as you proposed; your approach ended up being much easier. Internally, I now just convert everything to tuples (and, in the case of single arguments, unpack these before returning).
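
The tuple normalization described above might look roughly like this. This is only a sketch of the idea, not the code from the pull request; the helper names as_tuple, unpack_single and call_with_tuples are hypothetical.

```python
def as_tuple(value):
    """Turn lists into tuples, pass tuples through, wrap anything else in a 1-tuple."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value,)


def unpack_single(values):
    """Return the lone element of a 1-tuple, otherwise the tuple itself."""
    return values[0] if len(values) == 1 else values


def call_with_tuples(fn, inputs):
    """Call fn on normalized inputs and hand back outputs in caller-facing form."""
    outputs = as_tuple(fn(*as_tuple(inputs)))
    return unpack_single(outputs)


single = call_with_tuples(lambda a: a + 1, 1)               # single in, single out
swapped = call_with_tuples(lambda a, b: (b, a), [1, 2])     # list in, tuple out
```

Working with tuples internally means every code path can iterate over inputs and outputs uniformly, while single-tensor callers still see a plain value.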