silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Keep_input=False, keep_input_inverse=False issue both in the minimal example and a custom model #38

Closed ibro45 closed 4 years ago

ibro45 commented 4 years ago

I'm not able to run my models with keep_input=False and keep_input_inverse=False. While hunting down the issue, I tried to assert that the module is invertible using memcnn.is_invertible_module, as shown in the minimal example. Doing so resulted in a segmentation fault, which happens even for the minimal example if its keep_input and keep_input_inverse are set to False.

What I Did

Minimal Example

Changed keep_input and keep_input_inverse to False and the asserts to prints.

import torch
import torch.nn as nn
import memcnn

# define a new torch Module with a sequence of operations: ReLU o BatchNorm2d o Conv2d
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
                                    nn.Conv2d(in_channels=channels, out_channels=channels,
                                              kernel_size=(3, 3), padding=1),
                                    nn.BatchNorm2d(num_features=channels),
                                    nn.ReLU(inplace=True)
                                )

    def forward(self, x):
        return self.seq(x)

# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 10, 8, 8)

# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()

Y = model_normal(X)

# turn the ExampleOperation invertible using an additive coupling
invertible_module = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2)
)

# test that it is actually a valid invertible module (has a valid inverse method)
print(memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape))

# wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, keep_input=False, keep_input_inverse=False)

# by default the module is set to training, the following sets this to evaluation
# note that this is required to pass input tensors to the model with requires_grad=False (inference only)
invertible_module_wrapper.eval()

# test that the wrapped module is also a valid invertible module
print(memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape))

# compute the forward pass using the wrapper
Y2 = invertible_module_wrapper.forward(X)

# the input (X) can be approximated (X2) by applying the inverse method of the wrapper on Y2
X2 = invertible_module_wrapper.inverse(Y2)

# test that the input and approximation are similar
print(torch.allclose(X, X2, atol=1e-06))

The first check goes well, but the second one fails since the inputs are no longer kept. Output:

python minimal.py     
True
zsh: segmentation fault (core dumped)  python minimal.py

Furthermore, commenting out the checks as well as invertible_module_wrapper.eval(), while keeping keep_input and keep_input_inverse set to False, gives the following traceback in the minimal example.

Traceback (most recent call last):
  File "minimal.py", line 49, in <module>
    Y2 = invertible_module_wrapper.forward(X)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/memcnn/models/revop.py", line 120, in forward
    xin.register_hook(hook=partial(signal_hook, valid_states=self._valid_states, state_index=self._state_counter))
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 198, in register_hook
    raise RuntimeError("cannot register a hook on a tensor that "
RuntimeError: cannot register a hook on a tensor that doesn't require gradient

My model

Training works well when the inputs are kept, but it fails when the inputs have to be recomputed during backpropagation.

Traceback:

Traceback (most recent call last):
  File "train.py", line 35, in <module>
    model.optimize_parameters()
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/unpaired_revgan3d_model.py", line 154, in optimize_parameters
    self.forward()
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/unpaired_revgan3d_model.py", line 87, in forward
    self.fake_B = self.netG_A(self.real_A)
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/unpaired_revgan3d_model.py", line 49, in <lambda>
    self.netG_A = lambda x: self.netG(x)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/networks3d.py", line 824, in forward
    out64 = self.down_tr64(out32, inverse)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/networks3d.py", line 736, in forward
    down = self.down_conv_ab(x)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 480, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR

Unfortunately, cuDNN errors are not really informative; I have already encountered the same RuntimeError before, and in the end it had nothing to do with what the message states (the end of the message was exactly the same as above). However, I believe the two cases are connected, so I'm submitting them as a single issue.

silvandeleemput commented 4 years ago

@ibro45 Hi, thanks for your interest in MemCNN. I think you stumbled on an interesting bug with the is_invertible_module method.

This is a fundamental limitation of the memcnn.InvertibleModuleWrapper, which stems from the way the memory saving is implemented. Say you have f.forward(X) = Y, with f an invertible module using the wrapper. On the forward pass, X is discarded by setting its underlying storage to size zero. If you try to access the data of X at this point, it will yield a segfault. The inverse works in a similar way.
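To illustrate, here is a rough sketch of that mechanism (an approximation of the idea, not the exact MemCNN internals):

import torch

# Sketch: after the forward pass, the wrapper frees the input tensor's
# underlying storage to save memory.
X = torch.rand(2, 10, 8, 8)
X.storage().resize_(0)       # underlying storage now holds zero elements
print(X.storage().size())    # prints 0; the tensor header still reports shape (2, 10, 8, 8),
                             # but reading X's data at this point is invalid and can segfault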

Now looking at the code for is_invertible_module:

https://github.com/silvandeleemput/memcnn/blob/e0a0288e5b189443c87d299881c31239d6e9b029/memcnn/models/revop.py#L263-L295

We can see that line 287 roughly tests whether f.inverse(f.forward(X)) == X. However, after f.forward(X), X is already in an invalid state and will remain so until the output X2 = f.inverse(f.forward(X)) is compared against X (X2 == X). This will yield a segfault when f is wrapped in a memcnn.InvertibleModuleWrapper.
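Schematically, the failing pattern looks like this with keep_input=False (a sketch reusing invertible_module and X from the minimal example above, not the actual is_invertible_module source; it illustrates the crash, so don't actually run it):

# Sketch of the failing test pattern for a wrapped module with keep_input=False
f = memcnn.InvertibleModuleWrapper(fn=invertible_module,
                                   keep_input=False, keep_input_inverse=False)
f.eval()
Y = f.forward(X)               # X's storage is freed inside forward()
X2 = f.inverse(Y)              # Y's storage is freed inside inverse()
ok = torch.allclose(X, X2)     # reads X's freed data -> segmentation fault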

For now, I have fixed the is_invertible_module method for modules wrapped with the memcnn.InvertibleModuleWrapper by simply testing only the underlying function in that case. This should avoid segfaults when using the method altogether.
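The idea is roughly the following (a sketch of the approach, not the actual MemCNN 1.2.1 implementation; the _fn attribute name is assumed from the traceback above):

import torch
import memcnn

# Sketch: when given an InvertibleModuleWrapper, test the wrapped function
# instead, so the wrapper never gets the chance to free the test input's storage.
def is_invertible_module_sketch(module, test_input_shape, atol=1e-6):
    if isinstance(module, memcnn.InvertibleModuleWrapper):
        module = module._fn  # attribute name assumed from the traceback above
    with torch.no_grad():
        x = torch.rand(*test_input_shape)
        y = module.forward(x)
        x2 = module.inverse(y)
    return torch.allclose(x, x2, atol=atol)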

I rolled out the release of MemCNN 1.2.1, which has the fix.

ibro45 commented 4 years ago

Hi, thank you for the explanation and the prompt fix of the is_invertible_module method; it works well now! Do you have a clue why the RuntimeError: cannot register a hook on a tensor that doesn't require gradient occurs?

You can reproduce it with the following modified minimal example, where I removed the asserts and added two variables: KEEP, which controls keep_input and keep_input_inverse, and EVAL, which controls whether the model is in eval() mode or not.

import torch
import torch.nn as nn
import memcnn

KEEP = True
EVAL = False

# define a new torch Module with a sequence of operations: ReLU o BatchNorm2d o Conv2d
class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
                                    nn.Conv2d(in_channels=channels, out_channels=channels,
                                              kernel_size=(3, 3), padding=1),
                                    nn.BatchNorm2d(num_features=channels),
                                    nn.ReLU(inplace=True)
                                )

    def forward(self, x):
        return self.seq(x)

# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 10, 8, 8)

# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()

Y = model_normal(X)

# turn the ExampleOperation invertible using an additive coupling
invertible_module = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2)
)

invertible_module_wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module, 
                                                           keep_input=KEEP, 
                                                           keep_input_inverse=KEEP)

if EVAL:
    invertible_module_wrapper.eval()
else:
    invertible_module_wrapper.train()

Y2 = invertible_module_wrapper.forward(X)
X2 = invertible_module_wrapper.inverse(Y2)

While the invertible module seems to work in evaluation mode both with and without keeping the inputs, it breaks in train mode (EVAL = False), regardless of whether the inputs are kept.

Traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-17483c8f6451> in <module>
     45 
     46 
---> 47 Y2 = invertible_module_wrapper.forward(X)
     48 X2 = invertible_module_wrapper.inverse(Y2)

~/anaconda3/envs/maastro/lib/python3.7/site-packages/memcnn/models/revop.py in forward(self, xin)
    118             if self.training:
    119                 self._valid_states.append(True)
--> 120                 xin.register_hook(hook=partial(signal_hook, valid_states=self._valid_states, state_index=self._state_counter))
    121                 y.register_hook(hook=partial(backward_hook, keep_input=self.keep_input,
    122                                              compute_input_fn=self._fn.inverse, compute_output_fn=self._fn.forward,

~/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py in register_hook(self, hook)
    196         """
    197         if not self.requires_grad:
--> 198             raise RuntimeError("cannot register a hook on a tensor that "
    199                                "doesn't require gradient")
    200         if self._backward_hooks is None:

RuntimeError: cannot register a hook on a tensor that doesn't require gradient

silvandeleemput commented 4 years ago

@ibro45 Hi, yes, I think I can explain that. The input variable (X) was not set to require gradients. The input variable must always require gradients for training in PyTorch. This can be achieved with:

X.requires_grad = True

If you set this property on the tensor before calling forward when training, you should be fine.
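Applied to your modified minimal example, that would look like this (a sketch; the flag is only needed for the training-mode forward pass):

# Mark the input as requiring gradients before the training-mode forward pass
# of the wrapper (names taken from the modified minimal example above).
X = torch.rand(2, 10, 8, 8)
X.requires_grad = True   # equivalently: X = torch.rand(2, 10, 8, 8, requires_grad=True)

invertible_module_wrapper.train()
Y2 = invertible_module_wrapper.forward(X)
X2 = invertible_module_wrapper.inverse(Y2)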

I hope this helps you.

silvandeleemput commented 4 years ago

This seems solved, I am closing the issue now.