Apparently, this issue is present in PyTorch versions 1.3.1 and 1.4.0, but not in 1.1.0. Interestingly enough, it is present even in 1.1.0 when the above example is used in my GAN setup.
Hi @silvandeleemput, sorry for being so active in the issues section this week.
No problem!
I'm trying to implement a residual connection in my 3D-UNet-like architecture, but I keep getting:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
This is once again a (somewhat tricky) problem/limitation with the way the InvertibleModuleWrapper functions. As a workaround you could, for now, remove or disable the InvertibleModuleWrapper, for example by constructing it with disable=True (see the sketch below). This should remove the error altogether; however, it disables any memory savings as well. Alternatively, you could just drop the shortcut connection, since I believe it is somewhat redundant if you use the AdditiveCoupling class. This line could then be removed:
out = out + down
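For illustration, the disable workaround could look like this. This is only a sketch; it assumes the disable flag of InvertibleModuleWrapper, which runs the wrapped module as a regular layer (no inputs are freed, and hence no memory is saved either):

    import memcnn
    from torch import nn

    # AdditiveCoupling splits the input along the channel dimension,
    # so Fm and Gm each see half of the channels (here: 8 of 16).
    invertible_module = memcnn.AdditiveCoupling(
        Fm=nn.Conv3d(8, 8, kernel_size=3, padding=1),
        Gm=nn.Conv3d(8, 8, kernel_size=3, padding=1)
    )
    # disable=True makes the wrapper a plain pass-through: the error
    # disappears, at the cost of the memory savings.
    rev_block = memcnn.InvertibleModuleWrapper(fn=invertible_module,
                                               keep_input=False,
                                               keep_input_inverse=False,
                                               disable=True)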
The problem is in the way MemCNN currently implements the memory saving, namely by calling backward inside the backward hook, but I want to change this (see #37). The idea is to replace the backward hook with a custom autograd Function.
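To illustrate that idea, here is a minimal sketch (not MemCNN's actual implementation) of such a custom autograd Function: it stores only the output, and during the backward pass reconstructs the input through the inverse before recomputing the forward pass under grad mode:

    import torch

    class InvertibleCheckpoint(torch.autograd.Function):
        # usage sketch: y = InvertibleCheckpoint.apply(x, coupling.forward, coupling.inverse)

        @staticmethod
        def forward(ctx, x, fn, fn_inverse):
            ctx.fn, ctx.fn_inverse = fn, fn_inverse
            with torch.no_grad():
                y = fn(x)                      # forward pass without building a graph
            ctx.save_for_backward(y.detach())  # only the output is stored
            # (the real wrapper would additionally free x here when keep_input=False)
            return y

        @staticmethod
        def backward(ctx, grad_y):
            y, = ctx.saved_tensors
            with torch.no_grad():
                x = ctx.fn_inverse(y)          # reconstruct the freed input from the output
            with torch.enable_grad():
                x = x.detach().requires_grad_()
                y_again = ctx.fn(x)            # rebuild the local graph
                y_again.backward(grad_y)       # fills x.grad and the parameter grads
            return x.grad, None, None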
@ibro45 I have just released MemCNN 1.3.0 which should fix your issue and issue #37. Please give it a go and let me know if it works for you.
EDIT: From 1.3.0 onward, MemCNN should also support implementing the residual layer as you did in the above example.
Hi @silvandeleemput, sorry for not getting back earlier, and thank you for solving it that quickly! I've been using it these days; it works perfectly in my GAN setup when the inputs are kept, but crashes when the inputs have to be recomputed during backprop. However, I don't think it is caused by the residual connection, so I will close this issue. As soon as I identify the cause of the current issue, I will let you know.
Hi Sil, I finally looked closer at the issue, and it seems that the residual connection does not work when the inputs are not kept. As I mentioned before, it does work if the inputs are kept, so knowing that might help track down the cause of the problem.
You can reproduce it with this code:
import torch
from torch import nn, optim
import memcnn


class RevBlock(nn.Module):
    def __init__(self, nchan):
        super(RevBlock, self).__init__()
        invertible_module = memcnn.AdditiveCoupling(
            Fm=self.build_conv_block(nchan//2),
            Gm=self.build_conv_block(nchan//2)
        )
        # inputs are not kept, so they are freed after the forward pass
        # and must be recomputed from the outputs during backprop
        self.rev_block = memcnn.InvertibleModuleWrapper(fn=invertible_module,
                                                        keep_input=False,
                                                        keep_input_inverse=False)

    def build_conv_block(self, nchan):
        return nn.Sequential(nn.Conv3d(nchan, nchan, kernel_size=5, padding=2),
                             nn.BatchNorm3d(nchan),
                             nn.PReLU(nchan))

    def forward(self, x, inverse=False):
        if inverse:
            return self.rev_block.inverse(x)
        else:
            return self.rev_block(x)


class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs):
        super(DownTransition, self).__init__()
        outChans = 2 * inChans
        self.down_conv_ab = self.build_down_conv(inChans, outChans)
        self.down_conv_ba = self.build_down_conv(inChans, outChans)
        self.core = nn.Sequential(*[RevBlock(outChans) for _ in range(nConvs)])
        self.relu = nn.PReLU(outChans)

    def build_down_conv(self, inChans, outChans):
        return nn.Sequential(nn.Conv3d(inChans, outChans, kernel_size=2, stride=2),
                             nn.BatchNorm3d(outChans),
                             nn.PReLU(outChans))

    def forward(self, x, inverse=False):
        if inverse:
            down_conv = self.down_conv_ba
            core = reversed(self.core)
        else:
            down_conv = self.down_conv_ab
            core = self.core
        down = down_conv(x)
        out = down
        for block in core:
            out = block(out, inverse=inverse)
        out = out + down  # residual connection where the error occurs
        return self.relu(out)


device = 'cuda:0'
model = DownTransition(16, 2).to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for i in range(10):
    print(i)
    optimizer.zero_grad()
    data, target = torch.rand((2, 16, 64, 64, 64)), torch.rand((2, 32, 32, 32, 32))
    data, target = data.to(device), target.to(device)
    out = model.forward(data)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
Traceback 1:
Warning: Traceback of forward call that caused the error:
  File "issue_w_not_keep_inputs.py", line 76, in <module>
    loss.backward()
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
(print_stack at /opt/conda/conda-bld/pytorch_1579022060824/work/torch/csrc/autograd/python_anomaly_mode.cpp:57)
Traceback (most recent call last):
  File "issue_w_not_keep_inputs.py", line 76, in <module>
    loss.backward()
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ft002207/anaconda3/envs/maastro140/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: CUDA error: an illegal memory access was encountered
Traceback 2:
When run with CUDA_LAUNCH_BLOCKING=1 set, the traceback shows that the problem is caused by the residual connection. However, these CUDA illegal-memory-access errors don't show why it happens.
Traceback (most recent call last):
  File "issue_w_not_keep_inputs.py", line 74, in <module>
    out = model.forward(data)
  File "issue_w_not_keep_inputs.py", line 58, in forward
    out = out + down
RuntimeError: CUDA error: an illegal memory access was encountered
Also, when out = out + down is removed, it works fine.
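By the way, the CUDA_LAUNCH_BLOCKING setting used above can also be applied from inside the script itself. A minimal sketch; the variable has to be set before torch initializes CUDA:

    import os

    # must happen before `import torch`, otherwise CUDA is already initialized
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    import torch  # imported only after setting the environment variable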
@ibro45 Hi, OK, I think I figured out what the problem is. It resides in the DownTransition module class, which does the following:
out = down
for block in core:
    out = block(out, inverse=inverse)
out = out + down
What happens is that first a reference to the input tensor, called down, is created and initialized with values in memory as normal. Next, the loop over the cores discards the memory associated with that referenced tensor. Finally, when the referenced (emptied) tensor is added to out, it crashes, because the emptied tensor cannot be added to the output tensor.
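For example, assuming (as a simplification) that the wrapper discards the input by shrinking its underlying storage via storage().resize_(0), the failure mode boils down to this sketch:

    import torch

    down = torch.rand(4)
    out = torch.rand(4)

    # Conceptually what the memory-saving wrapper does to `down` after the
    # block has run: the Python reference survives, but its data is discarded.
    down.storage().resize_(0)

    # Reading the emptied tensor is invalid: on CUDA this surfaces as an
    # illegal memory access, on CPU as an error or even a segmentation fault.
    out = out + down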
What you could do is set keep_input=True for only the first element in the chain, like so, to keep the data of the referenced tensor down around:
out = down
for i, block in enumerate(core):
    if i == 0:
        block.rev_block.keep_input = True
    out = block(out, inverse=inverse)
out = out + down
This way, only the feature maps going into the first block are kept in memory, while those of the rest of the sequence of cores can still be discarded.
P.S. As a side note, it is generally better practice to initialize the device object like this:
device = torch.device('cuda:0')
I hope this explanation helps!
Thank you so much Sil!
Since I also use the InvertibleModuleWrapper's inverse(), I had to modify your suggestion to:
out = down
for i, block in enumerate(core):
    if i == 0:
        if inverse:
            block.rev_block.keep_input_inverse = True
        else:
            block.rev_block.keep_input = True
    out = block(out, inverse=inverse)
out = out + down
Thanks for the explanation too!
device = torch.device('cuda:0')
Indeed, that was just me switching to GPU quickly since CPU would only give a segmentation fault...
Description

Hi @silvandeleemput, sorry for being so active in the issues section this week. I'm trying to implement a residual connection in my 3D-UNet-like architecture, but I keep getting:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I've also asked about this issue on the PyTorch forum here.

What I Did

I made this minimal example of the issue I have. I've been trying to debug it, but I can't see anything wrong in the code I've written so far, so I am wondering if you have encountered anything similar while developing MemCNN, because I am not sure where the issue is coming from. The line where the problem starts is out = down in the DownTransition class. The minimal example and the tracebacks are given at the top of this thread.