jambo6 closed this issue 2 years ago
Hi, thanks for your question. My guess is that when you call del inputs, only the reference to the tensor is deleted, not the computation graph with the storage attached to the tensors. At least when performing a forward pass in training mode, where you require gradients, the inputs are typically retained.
This is by design: the computation graph needs the inputs for the backward pass, so deleting the input reference does not actually free the memory, as shown in the following example:
import torch
device = torch.device('cuda')
print(torch.cuda.memory_allocated(device=device))
with torch.set_grad_enabled(True):  # this is by default, so same result if this is omitted
    inputs = torch.ones(4, 8, 16, device=device)
    print(torch.cuda.memory_allocated(device=device))
    output = torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=1, device=device)(inputs)
    print(torch.cuda.memory_allocated(device=device))
    # Here the memory can't be freed, since a reference is retained in memory by the computation graph (it is required for the backward pass)
    del inputs
    print(torch.cuda.memory_allocated(device=device))
Output:
0
2048
5120
5120
However, the following does work, since without gradient tracking the inputs are no longer required once the output has been computed:
import torch
device = torch.device('cuda')
print(torch.cuda.memory_allocated(device=device))
with torch.no_grad():
    inputs = torch.ones(4, 8, 16, device=device)
    print(torch.cuda.memory_allocated(device=device))
    output = torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=1, device=device)(inputs)
    print(torch.cuda.memory_allocated(device=device))
    # Here the memory for the input can be freed since a reference is not needed for a backward pass by the computation graph
    del inputs
    print(torch.cuda.memory_allocated(device=device))
Output:
0
2048
5120
2048
So sadly it isn't as simple as deleting some references to tensors when you also require gradient computation/network training.
MemCNN achieves its memory savings during training by keeping the computation graph intact, but setting the underlying storage of the input tensors to size zero during the forward pass with storage().resize_(0), and later restoring the storage to its original size and copying the reconstructed input back in when it is needed.
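To make that a bit more concrete, here is a rough, simplified sketch of the storage trick in isolation. This is not MemCNN's actual implementation (MemCNN does this inside a custom torch.autograd.Function); going through .data bypasses autograd's in-place checks and is only safe here because the exact same values are restored before the backward pass, and the behaviour may differ slightly between PyTorch versions:

import torch

device = torch.device('cuda')

inputs = torch.ones(4, 8, 16, device=device, requires_grad=True)
conv = torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=1, device=device)

output = conv(inputs)
print(torch.cuda.memory_allocated(device=device))   # input, conv parameters and output are allocated

# Free only the *storage* of the input; the tensor object and the graph stay alive
inputs.data.storage().resize_(0)
print(torch.cuda.memory_allocated(device=device))   # the input's memory has been released

# Before the backward pass, restore the storage and copy the reconstructed values back in
# (here we simply recreate the known constant input as a stand-in for the inverse computation)
reconstructed = torch.ones(4, 8, 16, device=device)
inputs.data.storage().resize_(inputs.numel())
inputs.data.copy_(reconstructed)

output.sum().backward()
print(inputs.grad.shape)                             # the backward pass still works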
Alternative strategies could involve detaching the graph node tensors and subsequently deleting them, but then you lose the computation graph. Otherwise you would have to write your own PyTorch C++ implementation of the invertible operations.
Hello,
Sorry for the slow reply, but thanks a lot for this explanation; it makes perfect sense.
Do you have much experience with monitoring the numerical errors in the backward pass that arise due to floating-point arithmetic? I tried a couple of networks and saw that numerical error does indeed appear and then seems to grow exponentially when compared with the true value of the gradients; however, I did not do any formal study.
If you have done any experiments, or have any thoughts on reducing numerical error, it would be appreciated!
I have done a few experiments on the numerical error of the reconstructed gradients w.r.t. the normal gradients and indeed I also found that the error increased exponentially with the number of invertible layers.
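In case it is useful, here is a self-contained sketch of such a comparison in plain PyTorch (toy additive couplings built from Linear layers, so not MemCNN itself): compute the parameter gradients once with ordinary autograd, and once with a RevNet-style backward pass that reconstructs each layer's input from its output, then look at the largest difference:

import torch

torch.manual_seed(0)
dim, batch, depth = 16, 8, 20

# A toy stack of additive couplings: y1 = x1 + F(x2), y2 = x2 + G(y1),
# inverted (exactly, in exact arithmetic) by x2 = y2 - G(y1), x1 = y1 - F(x2)
Fs = [torch.nn.Linear(dim, dim) for _ in range(depth)]
Gs = [torch.nn.Linear(dim, dim) for _ in range(depth)]
params = [p for m in Fs + Gs for p in m.parameters()]


def couple(F, G, a, b):
    y1 = a + F(b)
    y2 = b + G(y1)
    return y1, y2


x1, x2 = torch.randn(batch, dim), torch.randn(batch, dim)

# --- reference gradients: ordinary autograd, all activations kept in memory ---
a, b = x1, x2
for F, G in zip(Fs, Gs):
    a, b = couple(F, G, a, b)
loss = (a ** 2 + b ** 2).sum()
true_grads = dict(zip(params, torch.autograd.grad(loss, params)))

# --- RevNet-style gradients: keep only the final output, reconstruct the inputs ---
with torch.no_grad():
    a, b = x1, x2
    for F, G in zip(Fs, Gs):
        a, b = couple(F, G, a, b)
y1 = a.requires_grad_()
y2 = b.requires_grad_()
loss = (y1 ** 2 + y2 ** 2).sum()
gy1, gy2 = torch.autograd.grad(loss, (y1, y2))

recon_grads = {}
for i in reversed(range(depth)):
    F, G = Fs[i], Gs[i]
    with torch.no_grad():                      # invert to recover this layer's input
        b_rec = y2 - G(y1)
        a_rec = y1 - F(b_rec)
    a_rec.requires_grad_()
    b_rec.requires_grad_()
    z1, z2 = couple(F, G, a_rec, b_rec)        # recompute the forward pass locally
    layer_params = list(F.parameters()) + list(G.parameters())
    grads = torch.autograd.grad((z1, z2), (a_rec, b_rec, *layer_params),
                                grad_outputs=(gy1, gy2))
    gy1, gy2 = grads[0], grads[1]              # gradients w.r.t. the layer's input
    recon_grads.update(zip(layer_params, grads[2:]))
    y1, y2 = a_rec.detach(), b_rec.detach()    # outputs of the layer below

err = max((true_grads[p] - recon_grads[p]).abs().max().item() for p in params)
print(f"max |grad_normal - grad_reconstructed| over {depth} couplings: {err:.3e}")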
Also check figure 4 from the RevNet paper: https://arxiv.org/pdf/1707.04585.pdf. Note the section in the paper on the numerical error with a suggestion to combat it if necessary (with some additional overhead):
Numerical error. While Eqn. 8 reconstructs the activations exactly when done in exact arithmetic,
practical float32 implementations may accumulate numerical error during backprop. We study the
effect of numerical error in Section 5.2; while the error is noticeable in our experiments, it does not
significantly affect final performance. We note that if numerical error becomes a significant issue,
one could use fixed-point arithmetic on the x’s and y’s (but ordinary floating point to compute F and
G), analogously to [19]. In principle, this would enable exact reconstruction while introducing little
overhead, since the computation of the residual functions and their derivatives (which dominate the
computational cost) would be unchanged.
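The float32 accumulation the paper describes is also easy to reproduce without any gradients at all. The small sketch below (again toy Linear couplings, not MemCNN code) pushes an input through a stack of additive couplings, inverts the stack again, and prints the reconstruction error, which grows with the depth:

import torch

torch.manual_seed(0)


def reconstruction_error(depth, dim=16, batch=8):
    # Forward:  y1 = x1 + F(x2),  y2 = x2 + G(y1)
    # Inverse:  x2 = y2 - G(y1),  x1 = y1 - F(x2)
    # Exact in exact arithmetic; in float32 every addition/subtraction rounds,
    # so the reconstruction drifts further from the original input with depth
    Fs = [torch.nn.Linear(dim, dim) for _ in range(depth)]
    Gs = [torch.nn.Linear(dim, dim) for _ in range(depth)]
    x1, x2 = torch.randn(batch, dim), torch.randn(batch, dim)
    orig1, orig2 = x1.clone(), x2.clone()
    with torch.no_grad():
        for F, G in zip(Fs, Gs):
            x1 = x1 + F(x2)
            x2 = x2 + G(x1)
        for F, G in reversed(list(zip(Fs, Gs))):
            x2 = x2 - G(x1)
            x1 = x1 - F(x2)
    return max((x1 - orig1).abs().max().item(), (x2 - orig2).abs().max().item())


for depth in (1, 5, 10, 25, 50):
    print(depth, reconstruction_error(depth))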
Thanks!
Hello,
Thanks a bunch for this package. I've had a go at implementing something similar following your code, and I have one query about the implementation. It relates to the following lines:
Why do you go via this approach to free up the memory? I thought it would have been sufficient to do something along the lines of
and then simply reconstruct with
ignoring the storage().resize_ lines altogether. However, the second approach doesn't work, whereas your approach frees the memory. Could you enlighten me as to what exactly is going on in the first approach, and why it is required over something more straightforward where we simply free the inputs memory in forward and reconstruct in backward? Thanks!