silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

About AffineCoupling and AdditiveCoupling #63

Open xuedue opened 3 years ago

xuedue commented 3 years ago

Description

I want to implement an invertible MLP. When using AdditiveCoupling, it works: [screenshot: working AdditiveCoupling setup and output]

but when I change to AffineCoupling, it fails: [screenshot: AffineCoupling setup]

[screenshot: error traceback]

In addition, I would like to ask: to implement a reversible MLP, do I need to halve the input and output channels in Fm and Gm, as I wrote above?
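
Roughly, my setup looks like the following sketch (the two-layer MLP and the channel count of 10 here are just illustrative):

import torch
import memcnn

class MLP(torch.nn.Module):
    # illustrative half-width MLP: maps (batch_size, channels) to the same shape
    def __init__(self, channels):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(channels, channels),
            torch.nn.ReLU(),
            torch.nn.Linear(channels, channels),
        )

    def forward(self, x):
        return self.net(x)

# Fm and Gm each operate on one half of the 10 input channels
invertible_module = memcnn.AdditiveCoupling(
    Fm=MLP(channels=10 // 2),
    Gm=MLP(channels=10 // 2),
)

x = torch.randn(8, 10)                  # (batch_size, channels)
y = invertible_module.forward(x)        # same shape as x
x_rec = invertible_module.inverse(y)    # reconstructs x up to float precision
print(torch.allclose(x, x_rec, atol=1e-6))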

Looking forward to your answer, thanks!

silvandeleemput commented 3 years ago

@xuedue Hi, thanks for using MemCNN. Whereas memcnn.AdditiveCoupling expects Fm and Gm to take a single input x and produce a single output y of the same shape, memcnn.AffineCoupling expects Fm and Gm to produce an output of the form (s, t): a vector for the scale and a vector for the translation/shift. By default, memcnn.AffineCoupling does not provide any conversion for this, since its adapter parameter is set to None, but you can simply supply memcnn.AffineAdapterNaive or memcnn.AffineAdapterSigmoid as the adapter argument, which will wrap your Fm and Gm modules so that they output the (s, t) form.

For example, for your code:

self.invertible_module = memcnn.AffineCoupling(
   Fm=MLP(...),
   Gm=MLP(...),
   adapter=memcnn.AffineAdapterSigmoid
)

I hope this solves your issue.

xuedue commented 3 years ago

Thank you for your reply. I modified it according to your description, but it raised another error: [screenshots: modified code and error traceback]

If I use AffineAdapterNaive instead of AffineAdapterSigmoid, it works.

xuedue commented 3 years ago

Fm and Gm need to have input x and output y of the same shape. If I want to implement a reversible MLP with different input and output channels, what should I do? For example, the input of MLP is 100 dimensions, and the output is 2 dimensions.

silvandeleemput commented 3 years ago

> Fm and Gm need to have input x and output y of the same shape. If I want to implement a reversible MLP with different input and output channels, what should I do? For example, the input of MLP is 100 dimensions, and the output is 2 dimensions.

I assume here that with dimensions you mean the number of channels in the dimension you split on.

A simple way of doing so is to make a reversible MLP with the same input and output shapes, and then at the end extract only the first 2 channels using slicing:

output = self.invertible_module(x)
output_reduced = output[:, :2, :]  # depends a bit on the shape of your output

A more complex way, if you want to keep it fully invertible, is to create a torch.nn.Module that reshapes/reorders the data so it retains the same number of elements but now with 2 channels. In this way, you can still wrap it using the memcnn.InvertibleModuleWrapper.
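
For instance, a minimal sketch of such a reshaping module (the shapes here are just an example):

import torch
import memcnn

class ReshapeTo2Channels(torch.nn.Module):
    # illustrative invertible reshape: (batch_size, 100) <-> (batch_size, 2, 50)
    # the number of elements stays the same, so it has an exact inverse
    def forward(self, x):
        return x.view(x.shape[0], 2, 50)

    def inverse(self, y):
        return y.view(y.shape[0], 100)

reshape_module = memcnn.InvertibleModuleWrapper(fn=ReshapeTo2Channels(),
                                                keep_input=True, keep_input_inverse=True)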

xuedue commented 3 years ago

Well, like the following MLP example: [screenshot: MLP definition]. If I change the output to 100 dimensions and only take two of them, it doesn't make sense. For an MLP, the input and output are both (batch_size, dimension), and I think there is no way to convert (batch_size, 100) to (batch_size, 2). Instead of changing the data shape, are there any other functions or modules that can implement a reversible MLP network like the one in the picture above?

silvandeleemput commented 3 years ago

> If I change the output to 100 dimensions and only take two of them, it doesn't make sense.

Why? Could you elaborate? Doesn't this work for your use case?

Alternatively, you could also take the mean over the first 50 channels and the second 50 channels to reduce your output from (batch_size, 100) -> (batch_size, 2), if you're more comfortable with that (very similar to an average pooling operation).

output_100_channels = reversible_mlp_network.forward(input_vector)
print(output_100_channels.shape)  # output shape should be (batch_size, 100)
output_2_channels = torch.cat(
    (torch.mean(output_100_channels[:, :50], dim=1, keepdim=True),
     torch.mean(output_100_channels[:, 50:], dim=1, keepdim=True)), dim=1)
print(output_2_channels.shape)  # output shape should be (batch_size, 2)

Could you maybe tell me a bit more about which parts of your network you want to make reversible? How many layers do you intend your network to have? Typical reversible networks in the literature, like RevNet, break with their reversible nature at some point near the end of the network, before computing the loss and applying backpropagation.

silvandeleemput commented 3 years ago

> If I use AffineAdapterNaive instead of AffineAdapterSigmoid, it works.

Btw, if you still want to use the AffineAdapterSigmoid with the AffineCoupling, you can use the following modified AffineAdapterSigmoid implemention, which should fix your error:

import torch

class AffineAdapterSigmoidModified(torch.nn.Module):
    """ Sigmoid based affine adapter, modified to work for 2-dimensional (batch, channels) outputs

        Partitions the output h of f(x) = h into s and t by extracting every odd and even channel
        Outputs sigmoid(s), t
    """
    def __init__(self, module):
        super(AffineAdapterSigmoidModified, self).__init__()
        self.f = module

    def forward(self, x):
        h = self.f(x)
        assert h.shape[1] % 2 == 0  # nosec
        scale = torch.sigmoid(h[:, 1::2] + 2.0)
        shift = h[:, 0::2]
        return scale, shift
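
It can then be passed to the coupling in place of the stock adapter, e.g. (reusing the partial example from before):

self.invertible_module = memcnn.AffineCoupling(
   Fm=MLP(...),
   Gm=MLP(...),
   adapter=AffineAdapterSigmoidModified
)
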
xuedue commented 3 years ago

> Could you maybe tell me a bit more about which parts of your network you want to make reversible?

The reversible network I want to achieve is as follows: [screenshots: the MLP network definition, layers 1-7]

After getting the 1024-dimensional output, I will make some modifications to it, and then reversibly recover a 4096-dimensional input that has the same dimensionality as the original input but different values.

So averaging the output does not meet my needs. If the output is 4096-dimensional and an averaging operation is then performed to obtain 1024 dimensions, and those 1024 dimensions are then modified, the original 4096-dimensional input cannot be recovered from the modified 1024-dimensional output, because the averaging operation is not reversible.

silvandeleemput commented 3 years ago

Ok, thanks for clarifying your question. First, I would suggest making layers 1-6 invertible. This should be simple, since their in_features/out_features ratio is 1:1, which is what MemCNN supports very well: wrap the intermediate linear blocks with Affine or Additive coupling blocks as you did before. (These can subsequently be wrapped in memcnn.InvertibleModuleWrapper() to achieve memory savings, for example as sketched below.)
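
For example, assuming the intermediate layers each keep 512 features (so Fm and Gm act on 256-feature halves), the wrapping could look roughly like this:

import torch
import memcnn

# hypothetical half-width block used inside each coupling
def half_block(channels):
    return torch.nn.Sequential(torch.nn.Linear(channels, channels), torch.nn.ReLU())

# a stack of invertible 512-feature layers, each wrapped for memory savings;
# with keep_input=False the inputs are freed and reconstructed during the backward pass
invertible_body = torch.nn.Sequential(*[
    memcnn.InvertibleModuleWrapper(
        fn=memcnn.AdditiveCoupling(Fm=half_block(256), Gm=half_block(256)),
        keep_input=False,
        keep_input_inverse=False,
    )
    for _ in range(5)
])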

As for the final layer (7): does it have to have 1024 output features? Why not increase the previous layers (1-6) to 1024 features, or decrease the out_features of layer 7 to 512? Either approach would be the simplest in your case, since you could then wrap it just like layers 1-6. Otherwise, you could duplicate your output once, but I doubt that is what you want.

Making your first layer (1) invertible as well seems the most problematic. If this is really desirable, consider giving it a matching number of in_features and out_features too.

An alternative for the first layer would be to try a very experimental pad-and-crop trick: apply zero padding to the output of the linear layer until in_features == out_features, wrap that with a coupling block, and crop the padded part again right after the coupling block. I could provide an example if you're interested.

xuedue commented 3 years ago

Thank you very much for your reply

What I want to explain is that I need to maintain the reversibility of the entire network, so that I can reversibly recover the 9216-dimensional input after operating on the 1024-dimensional output. Here, the 9216-dimensional input is fixed and must be included in the reversible network. My most important goal is to recover this 9216-dimensional input through a reversible network.

Of course, if I change the 1024-dimensional output to a 9216-dimensional output, the problem is solved, but keeping the output at 1024 dimensions makes the follow-up operations easier.

Can the network be reversible if the input must be a 9216-dimensional vector and the output is a 1024-dimensional vector?

If the input and output must be of the same dimension, then there will be relatively large limitations in practical applications.

silvandeleemput commented 3 years ago

> Can the network be reversible if the input must be a 9216-dimensional vector and the output is a 1024-dimensional vector?

As far as I know, this can't be made reversible in the general case. You still need all the elements of the final output to reconstruct the input (all 9216 features), so you can't throw anything away from the 9216-dimensional output vector without also throwing away some of the information needed to reconstruct the input. This follows from the mathematical nature of the reversible couplings: the outputs partially encode the inputs.

I just tried to work this out using the pad-and-crop strategy I described earlier, but it still won't work, for the same reasons: after applying the padding inside the coupling the output is invertible, but after cropping this is no longer the case.

Of course, there are some impractical special cases for which it would work, e.g. if the input were a 1024-feature vector duplicated 9 times, or a 1024-feature vector with the other 8192 features all being zero.

> If the input and output must be of the same dimension, then there will be relatively large limitations in practical applications.

It depends on what you want. In the literature (and in practice), people tend to mix invertible operations with normal non-invertible operations for things that change output dimensions and shapes, like pooling. MemCNN was designed with this mixing strategy in mind: you can easily make invertible modules memory-efficient using memcnn.InvertibleModuleWrapper, while still relying on normal operations where they make more sense.
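
For instance, a sketch of that mixing pattern (the block definition and the sizes here are illustrative):

import torch
import memcnn

# hypothetical shape-preserving block for the invertible, memory-efficient part
class HalfBlock(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.lin = torch.nn.Linear(channels, channels)

    def forward(self, x):
        return torch.relu(self.lin(x))

coupling = memcnn.AdditiveCoupling(Fm=HalfBlock(512), Gm=HalfBlock(512))
memory_efficient_part = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=False,
                                                       keep_input_inverse=False)

# ... followed by a normal, non-invertible head that changes the dimensionality
model = torch.nn.Sequential(memory_efficient_part, torch.nn.Linear(1024, 2))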

There are also some special techniques to change dimensions, like invertible down-sampling:

[figure: invertible down-sampling operation]

Source: https://arxiv.org/pdf/1802.07088.pdf. But these relate to the spatial dimensions and still maintain a 1:1 mapping in the number of elements.
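
A concrete sketch of that idea, using torch.nn.PixelUnshuffle from plain PyTorch (not part of MemCNN), which performs exactly this kind of space-to-depth rearrangement:

import torch

# space-to-depth: (N, C, H, W) -> (N, 4C, H/2, W/2); exactly undone by PixelShuffle
down = torch.nn.PixelUnshuffle(downscale_factor=2)
up = torch.nn.PixelShuffle(upscale_factor=2)

x = torch.randn(1, 3, 8, 8)
y = down(x)                    # shape (1, 12, 4, 4), same number of elements
print(torch.equal(up(y), x))   # True: a 1:1, fully invertible mapping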

You might be able to do something similar by splitting your output of 9216 elements over 9 batches of 1024. E.g. if your input has a batch size of 5 with 9216 features, your output becomes a batch size of 5x9=45 with 1024 features. In this way, you can keep it fully invertible.
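
A rough sketch of that splitting trick as an invertible module (shapes as in the example above):

import torch
import memcnn

class FoldFeaturesIntoBatch(torch.nn.Module):
    # illustrative invertible reshape: (batch_size, 9216) <-> (batch_size * 9, 1024)
    def forward(self, x):
        return x.reshape(x.shape[0] * 9, 1024)

    def inverse(self, y):
        return y.reshape(y.shape[0] // 9, 9216)

fold = FoldFeaturesIntoBatch()
x = torch.randn(5, 9216)
y = fold(x)                               # shape (45, 1024)
print(torch.equal(fold.inverse(y), x))    # True: fully invertible

# and it can still be wrapped for memory savings like any other invertible module
wrapped = memcnn.InvertibleModuleWrapper(fn=FoldFeaturesIntoBatch(),
                                         keep_input=True, keep_input_inverse=True)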

xuedue commented 3 years ago

Thank you very much for your explanation. If I run into other problems while using it, I will bother you again.