silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

Residual connection - Trying to backward through the graph a second time #39

Closed ibro45 closed 4 years ago

ibro45 commented 4 years ago

Description

Hi @silvandeleemput, sorry for being so active in the issues section this week. I'm trying to implement a residual connection in my 3D-UNet-like architecture, but I keep getting RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. I've also asked about this issue on the PyTorch forum here.

What I Did

I made this minimal example of the issue. I've been trying to debug it, but I can't see anything wrong in the code I've written so far, so I am wondering whether you encountered anything similar while developing MemCNN, because I am not sure where the issue is coming from. The problem starts at the line out = down in the DownTransition class.

Minimal Example

import torch
from torch import nn, optim
import memcnn

class RevBlock(nn.Module):
    def __init__(self, nchan):
        super(RevBlock, self).__init__()

        invertible_module = memcnn.AdditiveCoupling(
            Fm=self.build_conv_block(nchan//2),
            Gm=self.build_conv_block(nchan//2)
        )

        self.rev_block = memcnn.InvertibleModuleWrapper(fn=invertible_module, 
                                                        keep_input=True, 
                                                        keep_input_inverse=True)

    def build_conv_block(self, nchan):
        return nn.Sequential(nn.Conv3d(nchan, nchan, kernel_size=5, padding=2),
                             nn.BatchNorm3d(nchan),
                             nn.PReLU(nchan))

    def forward(self, x, inverse=False):
        if inverse:
            return self.rev_block.inverse(x)
        else:
            return self.rev_block(x)

class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs):
        super(DownTransition, self).__init__()
        outChans = 2*inChans
        self.down_conv_ab = self.build_down_conv(inChans, outChans)
        self.down_conv_ba = self.build_down_conv(inChans, outChans)
        self.core = nn.Sequential(*[RevBlock(outChans) for _ in range(nConvs)])
        self.relu = nn.PReLU(outChans)

    def build_down_conv(self, inChans, outChans):
        return nn.Sequential(nn.Conv3d(inChans, outChans, kernel_size=2, stride=2),
                             nn.BatchNorm3d(outChans),
                             nn.PReLU(outChans))

    def forward(self, x, inverse=False):
        if inverse:
            down_conv = self.down_conv_ba
            core = reversed(self.core)
        else:
            down_conv = self.down_conv_ab
            core = self.core

        down = down_conv(x)
        out = down    # the reason it breaks
        for block in core:
            out = block(out, inverse=inverse)

        out = out + down
        return self.relu(out)

model = DownTransition(16, 2)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for i in range(10):
    optimizer.zero_grad()
    data, target = torch.rand((2,16,64,64,64)), torch.rand((2,32,32,32,32))
    out = model.forward(data)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()

Traceback

Traceback (most recent call last):
  File "minimal.py", line 71, in <module>
    loss.backward()
  File "/home/bro/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/bro/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
ibro45 commented 4 years ago

Apparently, this issue is present in PyTorch versions 1.3.1 and 1.4.0 but not in 1.1.0.

ibro45 commented 4 years ago

Interestingly enough, it is also present in 1.1.0 when the above example is used in my GAN setup.

silvandeleemput commented 4 years ago

> Hi @silvandeleemput , sorry for being active in issues section this week.

No problem!

> I'm trying to implement a residual connection in my 3D-UNet-like architecture, but I keep getting RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

This is once again a (somewhat tricky) problem/limitation with the way the InvertibleModuleWrapper functions. As a workaround, you could for now remove or disable the InvertibleModuleWrapper, for example by constructing it with disable=True. This should remove the error altogether, but it also disables any memory savings. Alternatively, you could just drop the shortcut connection, since I believe it is somewhat redundant when you use the AdditiveCoupling class. This line could then be removed:

out = out + down
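A minimal sketch of why the shortcut is arguably redundant, assuming AdditiveCoupling splits its input in half along the channel dimension and computes y1 = x1 + F(x2), y2 = x2 + G(y1). The function additive_coupling below is a hypothetical reimplementation for illustration, not MemCNN's code. Each output half is its input half plus a learned term, so an identity path is already built in: with residual branches that return zeros, the coupling is exactly the identity.

```python
import torch


def additive_coupling(x, f, g):
    """Hypothetical additive coupling on channel dim 1 (not MemCNN's implementation)."""
    x1, x2 = torch.chunk(x, 2, dim=1)
    y1 = x1 + f(x2)  # identity path for x1 plus a learned residual
    y2 = x2 + g(y1)  # identity path for x2 plus a learned residual
    return torch.cat([y1, y2], dim=1)


# With branches that output zeros, nothing but the built-in identity path remains.
x = torch.randn(2, 4, 8, 8, 8)
y = additive_coupling(x, torch.zeros_like, torch.zeros_like)
print(torch.equal(x, y))  # True
```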

The problem is in the way MemCNN currently implements the memory saving by calling backward in the backward hook, but I want to change this (see #37). The idea is to do so by replacing the backward hook with a custom autograd Function.
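To make the idea concrete, here is a minimal sketch of a custom autograd Function for an additive coupling that saves only its output and reconstructs the input in backward. It is an illustration under my own assumptions, not MemCNN's actual implementation; the class name AdditiveCouplingFn is hypothetical. Unlike MemCNN, it does not free the input's storage, and it assumes F and G are deterministic and have no side effects (BatchNorm running statistics, for example, would be updated twice).

```python
import torch
from torch import nn


class AdditiveCouplingFn(torch.autograd.Function):
    """Hypothetical memory-saving coupling: y1 = x1 + F(x2), y2 = x2 + G(y1).
    Only the output is saved; the input is recomputed during backward."""

    @staticmethod
    def forward(ctx, x, f, g, *params):
        # Autograd is disabled inside Function.forward, so no graph is built here.
        ctx.f, ctx.g = f, g
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1 = x1 + f(x2)
        y2 = x2 + g(y1)
        y = torch.cat([y1, y2], dim=1)
        ctx.save_for_backward(y)  # the input x is NOT saved
        return y

    @staticmethod
    def backward(ctx, grad_y):
        f, g = ctx.f, ctx.g
        (y,) = ctx.saved_tensors
        y1, y2 = torch.chunk(y, 2, dim=1)
        with torch.no_grad():
            # Invert the coupling to recover the input from the output.
            x2 = y2 - g(y1)
            x1 = y1 - f(x2)
        with torch.enable_grad():
            # Replay the forward pass on the recovered input to build a local graph.
            x1 = x1.detach().requires_grad_(True)
            x2 = x2.detach().requires_grad_(True)
            y1_re = x1 + f(x2)
            y2_re = x2 + g(y1_re)
            y_re = torch.cat([y1_re, y2_re], dim=1)
            params = tuple(f.parameters()) + tuple(g.parameters())
            grads = torch.autograd.grad(y_re, (x1, x2) + params, grad_y)
        grad_x = torch.cat(grads[:2], dim=1)
        # One gradient per forward argument: x, f, g, then each parameter.
        return (grad_x, None, None) + tuple(grads[2:])


# Usage: the parameters are passed explicitly so autograd routes their gradients.
f, g = nn.Linear(3, 3), nn.Linear(3, 3)
x = torch.randn(5, 6, requires_grad=True)
y = AdditiveCouplingFn.apply(x, f, g, *f.parameters(), *g.parameters())
y.sum().backward()
print(x.grad.shape)  # torch.Size([5, 6])
```

The parameter order passed to apply must match the order used in backward (F's parameters, then G's).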

silvandeleemput commented 4 years ago

@ibro45 I have just released MemCNN 1.3.0 which should fix your issue and issue #37. Please give it a go and let me know if it works for you.

EDIT: From 1.3.0 onward, MemCNN should also support implementing the residual connection as you did in the example above.

ibro45 commented 4 years ago

Hi @silvandeleemput, sorry for not getting back earlier, and thank you for solving it so quickly! I've been using it these past days: it works perfectly in my GAN setup when the inputs are kept, but it crashes when the inputs have to be recomputed during backprop. However, I don't think this is caused by the residual connection, so I will close this issue. As soon as I identify the cause of the current problem, I will let you know.

ibro45 commented 4 years ago

Hi Sil, I finally took a closer look at the issue, and it seems that the residual connection does not work when the inputs are not kept. As I mentioned before, it does work when the inputs are kept, which might help track down the cause of the problem.

You can reproduce it with this code:

import torch
from torch import nn, optim
import memcnn

class RevBlock(nn.Module):
    def __init__(self, nchan):
        super(RevBlock, self).__init__()

        invertible_module = memcnn.AdditiveCoupling(
            Fm=self.build_conv_block(nchan//2),
            Gm=self.build_conv_block(nchan//2)
        )

        self.rev_block = memcnn.InvertibleModuleWrapper(fn=invertible_module, 
                                                        keep_input=False, 
                                                        keep_input_inverse=False)

    def build_conv_block(self, nchan):
        return nn.Sequential(nn.Conv3d(nchan, nchan, kernel_size=5, padding=2),
                             nn.BatchNorm3d(nchan),
                             nn.PReLU(nchan))

    def forward(self, x, inverse=False):
        if inverse:
            return self.rev_block.inverse(x)
        else:
            return self.rev_block(x)

class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs):
        super(DownTransition, self).__init__()
        outChans = 2*inChans
        self.down_conv_ab = self.build_down_conv(inChans, outChans)
        self.down_conv_ba = self.build_down_conv(inChans, outChans)
        self.core = nn.Sequential(*[RevBlock(outChans) for _ in range(nConvs)])
        self.relu = nn.PReLU(outChans)

    def build_down_conv(self, inChans, outChans):
        return nn.Sequential(nn.Conv3d(inChans, outChans, kernel_size=2, stride=2),
                             nn.BatchNorm3d(outChans),
                             nn.PReLU(outChans))

    def forward(self, x, inverse=False):
        if inverse:
            down_conv = self.down_conv_ba
            core = reversed(self.core)
        else:
            down_conv = self.down_conv_ab
            core = self.core

        down = down_conv(x)
        out = down
        for block in core:
            out = block(out, inverse=inverse)

        out = out + down
        return self.relu(out)

device = 'cuda:0'
model = DownTransition(16, 2).to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for i in range(10):
    print(i)
    optimizer.zero_grad()
    data, target = torch.rand((2,16,64,64,64)), torch.rand((2,32,32,32,32))
    data, target = data.to(device), target.to(device)
    out = model.forward(data)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()

Traceback 1:

Warning: Traceback of forward call that caused the error:
  File "issue_w_not_keep_inputs.py", line 76, in <module>
    loss.backward()
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
 (print_stack at /opt/conda/conda-bld/pytorch_1579022060824/work/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "issue_w_not_keep_inputs.py", line 76, in <module>
    loss.backward()
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA error: an illegal memory access was encountered

Traceback 2: When run with CUDA_LAUNCH_BLOCKING=1, the traceback shows that the problem is caused by the residual connection. However, CUDA illegal memory access errors like this one don't show why it happens.

Traceback (most recent call last):
  File "issue_w_not_keep_inputs.py", line 74, in <module>
    out = model.forward(data)
  File "issue_w_not_keep_inputs.py", line 58, in forward
    out = out + down
RuntimeError: CUDA error: an illegal memory access was encountered
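A minimal sketch of getting both kinds of traceback from inside the script. It assumes CUDA_LAUNCH_BLOCKING is set before CUDA is initialized (setting it after the first CUDA call has no effect) and uses torch.autograd.detect_anomaly, which is what prints the "Traceback of forward call that caused the error" warning seen in Traceback 1. The debug_step helper is hypothetical, and anomaly detection slows training down considerably, so it is meant for debugging only.

```python
import os

# Must be set before CUDA is initialised, i.e. before the first CUDA call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
from torch import nn


def debug_step(model, criterion, data, target):
    """Hypothetical helper: one training step with anomaly detection enabled."""
    with torch.autograd.detect_anomaly():
        out = model(data)
        loss = criterion(out, target)
        loss.backward()
    return loss.item()


# Tiny CPU example just to show the call; swap in DownTransition on 'cuda:0' to debug.
model = nn.Linear(4, 2)
loss = debug_step(model, nn.L1Loss(), torch.rand(3, 4), torch.rand(3, 2))
print(loss >= 0.0)  # True
```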

Also, when out = out + down is removed, it works fine.

silvandeleemput commented 4 years ago

@ibro45 Hi, OK, I think I figured out what the problem is. It resides in the DownTransition module class, which does the following:

        out = down
        for block in core:
            out = block(out, inverse=inverse)

        out = out + down

What happens is this: first, out is made a second reference to the tensor down, whose values are held in memory as normal. Next, since the inputs are not kept, the first block in the loop discards the memory of its input, which is exactly the tensor that down refers to. Finally, out = out + down crashes, because down now refers to an emptied tensor whose memory has already been released.
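The aliasing can be reproduced without MemCNN. This sketch assumes, as I understand it, that not keeping the input means the wrapper shrinks the input's storage to zero bytes after the forward pass; it also assumes PyTorch 2.x for untyped_storage(). It deliberately never reads the freed values, since doing so may crash (the segmentation fault on CPU and the illegal memory access on CUDA reported in this thread).

```python
import torch

down = torch.ones(2, 4)
out = down  # no copy: a second name for the same tensor
assert out is down

# Roughly what the first block does to its input when inputs are not kept
# (an assumption about MemCNN's internals, simulated by hand here).
out.untyped_storage().resize_(0)

# The metadata still claims 8 elements, but the memory behind `down` is gone,
# so a later `out + down` would read freed memory.
print(down.shape, down.untyped_storage().nbytes())  # torch.Size([2, 4]) 0
```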

What you could do is to pass keep_input=True only for the first element in the chain, so that the data for the referenced tensor down is kept around:

        out = down
        for i, block in enumerate(core):
            if i == 0:
                block.rev_block.keep_input = True
            out = block(out, inverse=inverse)

        out = out + down

This way, only the input feature maps of the first block are kept in memory, while the inputs of the remaining blocks in the sequence can still be discarded.
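The effect of keeping only the first block's input can be checked with a toy model. ToyFreeingBlock is a hypothetical stand-in for a RevBlock whose wrapper frees its input after the forward pass (again assuming the storage is released by resizing it to zero bytes); it is not a MemCNN class and ignores autograd entirely.

```python
import torch
from torch import nn


class ToyFreeingBlock(nn.Module):
    """Hypothetical stand-in for a wrapped block that does not keep its input."""

    def __init__(self):
        super().__init__()
        self.keep_input = False

    def forward(self, x):
        y = x * 2 + 1
        if not self.keep_input:
            x.untyped_storage().resize_(0)  # discard the input, as the wrapper would
        return y


def core_with_shortcut(down, blocks):
    out = down
    for i, block in enumerate(blocks):
        if i == 0:
            block.keep_input = True  # keep `down` alive for the shortcut below
        out = block(out)
    return out + down  # safe: `down` was never freed


down = torch.ones(3)
blocks = [ToyFreeingBlock(), ToyFreeingBlock()]
print(core_with_shortcut(down, blocks))  # tensor([8., 8., 8.])
```

Only the intermediate tensor between the two blocks is freed, which is fine because nothing else refers to it.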

P.S. As a side note, it is generally better practice to initialize the device object like this:

device = torch.device('cuda:0')
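A common variant of this pattern also falls back to the CPU when no GPU is available; the fallback is my addition, not part of the suggestion above. Passing the device object to factory functions such as torch.rand allocates the tensor directly on that device instead of creating it on the CPU and copying it with .to().

```python
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Allocate directly on the target device (small shape just for illustration).
data = torch.rand((2, 16, 8, 8, 8), device=device)
print(device.type, data.device.type)
```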

I hope this explanation helps!

ibro45 commented 4 years ago

Thank you so much, Sil! Since I also use InvertibleModuleWrapper's inverse(), I had to modify your suggestion to:

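out = down
for i, block in enumerate(core):
    if i == 0:
        # keep `down` alive for the shortcut, in whichever direction we run
        if inverse:
            block.rev_block.keep_input_inverse = True
        else:
            block.rev_block.keep_input = True
    out = block(out, inverse=inverse)

out = out + down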

Thanks for the explanation too!

> device = torch.device('cuda:0')

Indeed, that was just me quickly switching to the GPU, since on the CPU it would only give a segmentation fault...