silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

feat(coupling): Inplace Padding #54

Closed ClashLuke closed 3 years ago

ClashLuke commented 4 years ago

Currently, the functions are executed with the requirement that their input and output have exactly the same shape. That's why, in the past, I've used images padded along the channel dimension as input to my revnets. Knowing how terribly inefficient that is, the only other possible way to do it is to pad x0 and x1 before adding them to the function outputs. Similarly, in-place dropping, where only part of the channels is kept, would be equally useful.
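Roughly, what I have in mind for the additive case is something like the following untested sketch (placeholder code, assuming 4D NCHW inputs; this is not memcnn's current API):

import torch.nn.functional as F

# Untested sketch: an additive coupling step where the inputs are zero-padded along
# the channel dimension right before being added to the function outputs, so that
# Fm and Gm may return more channels than they receive.
def padded_additive_forward(x0, x1, Fm, Gm):
    fmd = Fm(x1)                                                  # e.g. 3 -> 64 channels
    y0 = F.pad(x0, [0, 0, 0, 0, 0, fmd.shape[1] - x0.shape[1]]) + fmd
    gmd = Gm(y0)
    y1 = F.pad(x1, [0, 0, 0, 0, 0, gmd.shape[1] - x1.shape[1]]) + gmd
    # the inverse would subtract Gm(y0) / Fm(x1) again and crop the padded channels away
    return y0, y1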

Known issue: you can't just replicate the functions anymore, since one takes 3 channels as in_channels while the other takes out_channels.

If you believe that this would be a useful addition to the memcnn library, I could implement that today.

silvandeleemput commented 4 years ago

Hi @ClashLuke, thank you for raising this issue. I understand that the precondition that the input and output of the couplings must have the same shape has been a limiting factor for your use case. It would be great if you could elaborate on your use case and provide some examples with input and output shapes (e.g. do you want to reduce a larger input to a smaller output, say 2 x 64 x 64 to 2 x 62 x 62?).

If what you are trying to achieve is non-volume-preserving couplings, then I can inform you that support for these could be a welcome addition to MemCNN. However, so far I have found that this can best be achieved by cleverly wrapping volume-preserving couplings and/or choosing good functions for F and G, so that might be the route you want to pursue.

ClashLuke commented 4 years ago

Thanks for the response, @silvandeleemput. My particular use case would be giving my model 3x64x64 data while using 128x64x64 as the size for the revnet. Similarly, a final classification layer of 2x64x64 should be supported without the need to pad both input and output. If I were to pad these, I could run the revnet, but the memory would explode: the model would store the first block, the last one, and the ones in between, and I want to get rid of the first and last. Therefore my proposal is simple feature indexing, which allows the revblock itself to contain the padding rather than forcing the user to provide it. This way I could pass in my 3x64x64 data and still have a full i-RevNet.

You're right though, passing Nx32x32 data into an Nx64x64 input would be intriguing as well. This could be done by giving AdditiveCoupling a target size on initialization; that way it would know the input and output sizes and could simply crop them as needed. The actual layer input and output would still have the same size - only the center part is changed, while the rest stays the same. Similarly, the input could be cropped too, to effectively allow reversible strides.

I'm not entirely sure if what I'm saying makes sense, but it's very trivial to implement. Perhaps an implementation is more concise than my explanation.
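As an untested sketch of that idea (coupling_fn here is just a placeholder for an invertible coupling acting on the cropped size, not an existing memcnn function):

import torch

# Untested sketch of the "only change the center part" idea: the layer's input and
# output keep the same overall size; only a center window of size (th, tw) is
# transformed by the coupling, the surrounding border is passed through unchanged.
def center_coupled(x, coupling_fn, th, tw):
    _, _, h, w = x.shape
    top, left = (h - th) // 2, (w - tw) // 2
    y = x.clone()
    y[:, :, top:top + th, left:left + tw] = coupling_fn(x[:, :, top:top + th, left:left + tw])
    return y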

silvandeleemput commented 4 years ago

If I understand you correctly, you want to pad your input channel dimension (with zeros or repeats?) before feeding it to a chain of reversible couplings. Since the padded input tensor wouldn't contain any new information you want the coupling to free the input tensor as well and reconstruct it on the backward pass. Is that correct?

If so, I think you could best write a custom invertible nn.Module that wraps a Coupling class and performs the padding on forward and inverse. By then wrapping that custom class with memcnn.InvertibleModuleWrapper, the memory management should be taken care of.

ClashLuke commented 4 years ago

Created a pull request; perhaps that illustrates my point a bit better. The idea is that you have the padding and cropping inside of your reversible function. Of course, cropping isn't reversible, but for all we care it is: you don't really need the gradients at the cropped positions anyway. Regarding padding: doing that inside of the revblock saves quite a bit of memory, which is why I'd very much prefer doing it this way rather than doing it the "normal" way. Naturally, to cover these edge cases, cursed code had to be written. I hope the eval calls aren't an issue for you, because I didn't see any other way of indexing a variable-sized tensor without allocating new memory.

silvandeleemput commented 4 years ago

Thanks for the PR, it clarifies your idea a bit better. I think you could still use the solution I suggested above to cover this use case, though. You can use the following tested example:

import torch
import torch.nn as nn
import memcnn

# NOTE: I only programmed this out for the channel dimension and it assumes a 2D spatial (4D NCHW) input
# This module only applies zero padding for simplicity
class CropPad2D(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super(CropPad2D, self).__init__()
        self.output_size = output_size
        self.input_size = input_size

    def forward(self, inp):
        assert inp.ndim == 4
        channels_diff = self.output_size[1] - self.input_size[1]
        assert channels_diff >= 0
        return torch.nn.functional.pad(input=inp, pad=[0, 0, 0, 0, 0, channels_diff], mode='constant', value=0)

    def inverse(self, out):
        assert out.ndim == 4
        b, c, h, w = self.input_size
        return out[0:b, 0:c, 0:h, 0:w].clone()

# define a new torch Module wrapper for an invertible coupling with cropping and padding
class CouplingCropPad(nn.Module):
    def __init__(self, coupling, input_size, output_size):
        super(CouplingCropPad, self).__init__()
        self.coupling = coupling
        self.input_crop = CropPad2D(input_size=input_size, output_size=output_size)

    def forward(self, x):
        x = self.input_crop(x)
        y = self.coupling(x)
        return y

    def inverse(self, y):
        x = self.coupling.inverse(y)
        x = self.input_crop.inverse(x)
        return x

# ===== Example code starts here =====

# define an example new torch Module with a sequence of operations: Relu o BatchNorm2d o Conv2d
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
                                    nn.Conv2d(in_channels=channels, out_channels=channels,
                                              kernel_size=(3, 3), padding=1),
                                    nn.BatchNorm2d(num_features=channels),
                                    nn.ReLU(inplace=True)
                                )

    def forward(self, x):
        return self.seq(x)

# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 2, 8, 8)

# make the ExampleOperation invertible using an additive coupling
coupling = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2)
)
modified_coupling = CouplingCropPad(coupling=coupling, input_size=(2, 2, 8, 8), output_size=(2, 10, 8, 8))
model = memcnn.InvertibleModuleWrapper(fn=modified_coupling, keep_input=True, keep_input_inverse=True)

# test that it is actually a valid invertible module (has a valid inverse method)
assert memcnn.is_invertible_module(model, test_input_shape=X.shape)
model.eval()

Y = model.forward(X)
print(Y.shape)  # 10 channels (after padding)

X2 = model.inverse(Y)
print(X2.shape)  # 2 channels (cropped back to the input size)

I also noted that while it should be relatively safe to perform padding before a coupling (it remains invertible), it is usually not safe to apply cropping after a coupling, since you potentially throw away information, making the module no longer invertible. If you still want to do that, you would have to build some constraints into your functions Fm and Gm to get it working. However, this seems rather tricky.

Next, I am not a big fan of the use of eval, so I would like to avoid that if possible. Maybe you could first check whether the provided solution works for you.
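As an aside, here is an untested sketch of how variable-dimensionality cropping could be written without eval, by building a tuple of slice objects instead of a string to evaluate (crop_to is just an illustrative helper, not part of MemCNN):

import torch

# Untested sketch: crop a tensor of arbitrary dimensionality without eval by
# constructing a tuple of slice objects (keep the first sizes[d] entries along
# every dimension d). Basic slicing returns a view, so no new memory is allocated.
def crop_to(tensor, sizes):
    return tensor[tuple(slice(0, s) for s in sizes)]

x = torch.rand(2, 10, 8, 8)
print(crop_to(x, (2, 2, 8, 8)).shape)  # torch.Size([2, 2, 8, 8])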

ClashLuke commented 4 years ago

Thanks for the input. If you'd rather have many implementations, one per dimensionality, then sure, we could go that route. You're right, cropping might be dangerous, so let's leave that out for now; padding is way more important to me. The major change introduced here, however, is that the padding is reversible and hence not stored in the computation graph. If I were to simply add this operation in front of the model (or inside of my layer), I'd either have one more operation stored in front of the computation graph or multiple reversible graphs, resulting in heavily increased memory consumption. Our way, however, there is only one graph.