silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

A question on implementation #71

Closed (jambo6 closed this issue 2 years ago)

jambo6 commented 2 years ago

Hello,

Thanks a bunch for this package. I've had a go at implementing something similar following your code, and I have one query about the implementation. It relates to the following lines:

if not ctx.keep_input:
    # PyTorch 1.0+ way to clear storage
    inputs[0].storage()
    for element in inputs:
        element.storage().resize_(0)
...
# recompute input
with torch.no_grad():
    inputs_inverted = ctx.fn_inverse(*outputs)
    if not isinstance(inputs_inverted, tuple):
        inputs_inverted = (inputs_inverted,)
    for element_original, element_inverted in zip(inputs, inputs_inverted):
        element_original.storage().resize_(int(np.prod(element_original.size())))
        element_original.set_(element_inverted)

Why do you take this approach to free up the memory?

I thought it would have been sufficient to do something along the lines of

del inputs
torch.cuda.empty_cache()

and then simply reconstruct with

with torch.no_grad():
    inputs_inverted = ctx.fn_inverse(*outputs)

ignoring the storage().resize_ lines altogether. However, the second approach doesn't actually free the memory, whereas your approach does.

Could you enlighten me as to what exactly is going on in the first approach, and why it is required over something more straightforward where we simply free the inputs' memory in the forward pass and reconstruct them in the backward pass?

Thanks!

silvandeleemput commented 2 years ago

Hi, thanks for your question. My guess is that when you call del inputs, only the Python reference to the tensor is deleted, not the actual computation graph with the storage attached to the tensors. At least when performing a forward pass in training mode, where you require gradients, the inputs are typically retained by the computation graph.

This is by design, because otherwise deleting references would invalidate the computation graph (which needs those tensors for the backward pass). That is, when you create an input that is part of a computation graph, deleting the input will not free its memory, as shown in the following example:

import torch

device = torch.device('cuda')
print(torch.cuda.memory_allocated(device=device))

with torch.set_grad_enabled(True):  # this is the default, so the result is the same if this is omitted
  inputs = torch.ones(4, 8, 16, device=device)
  print(torch.cuda.memory_allocated(device=device))
  output = torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=1, device=device)(inputs)
  print(torch.cuda.memory_allocated(device=device))

# Here the memory can't be freed, since a reference is retained in memory by the computation graph (it is required for the backward pass)
del inputs
print(torch.cuda.memory_allocated(device=device))

Output:

0
2048
5120
5120

However, this should work, since the inputs are no longer required when the output is computed:

import torch

device = torch.device('cuda')
print(torch.cuda.memory_allocated(device=device))

with torch.no_grad():
  inputs = torch.ones(4, 8, 16, device=device)
  print(torch.cuda.memory_allocated(device=device))
  output = torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=1, device=device)(inputs)
  print(torch.cuda.memory_allocated(device=device))

# Here the memory for the input can be freed since a reference is not needed for a backward pass by the computation graph 
del inputs
print(torch.cuda.memory_allocated(device=device))

Output:

0
2048
5120
2048

So sadly it isn't as simple as deleting some references to tensors when you also require gradient computations/network training.

MemCNN achieves its memory savings during training by keeping the computation graph intact, but shrinking the underlying storage of the input tensors to size 0 during the forward pass with storage().resize_(0). During the backward pass it restores the storage to its original size and copies the reconstructed input back in when needed.
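To make this concrete, here is a minimal, self-contained sketch of the same trick. It is a toy stand-in rather than the actual MemCNN code: the "invertible function" is just multiplication by 2, and a CUDA device is assumed.

import torch

device = torch.device('cuda')
x = torch.ones(4, 8, 16, device=device, requires_grad=True)
y = 2.0 * x  # toy invertible function; its inverse is y / 2.0
print(torch.cuda.memory_allocated(device=device))  # storage for both x and y

# Forward-pass trick: release x's buffer while keeping the tensor object
# (and hence the computation graph that references it) alive.
x.storage().resize_(0)
print(torch.cuda.memory_allocated(device=device))  # x's buffer has been freed

# Before the backward pass: recompute x from y via the inverse and copy it
# back in, mirroring the resize_/set_ lines quoted above.
with torch.no_grad():
    x_inverted = y / 2.0
    x.storage().resize_(x.numel())  # same as int(np.prod(x.size()))
    x.set_(x_inverted)

y.sum().backward()   # the backward pass now runs as usual
print(x.grad.shape)  # gradients w.r.t. the restored input

The important point is that the tensor object x referenced by the graph is never deleted; only its underlying storage is temporarily emptied and later filled again with the reconstructed values.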

Alternative strategies could involve detaching the graph node tensors and subsequently deleting them, but then you'll lose the computation graph. Or you would have to write your own PyTorch C implementation for invertible operations.
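For completeness, a small sketch of the detach alternative mentioned above (again just a toy example assuming a CUDA device): it does free the input, but the graph behind the output is gone afterwards, so no gradients can be computed through it anymore.

import torch

device = torch.device('cuda')
x = torch.ones(4, 8, 16, device=device, requires_grad=True)
y = torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=1, device=device)(x)
print(torch.cuda.memory_allocated(device=device))

y = y.detach()  # dropping the graph node also drops its saved reference to x
del x           # so the input's memory can now actually be freed
print(torch.cuda.memory_allocated(device=device))  # lower than before
print(y.requires_grad)  # False: gradients can no longer flow back through this point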

jambo6 commented 2 years ago

Hello,

Sorry for the slow reply on this, but thanks a lot for the explanation; it makes perfect sense.

Do you have much experience in monitoring the numerical errors in the backward pass that arise due to floating-point arithmetic? I tried a couple of networks and saw that numerical error does indeed appear and then seems to grow exponentially when compared with the true value of the gradients; however, I did not do any formal study.

If you have done any experiments, or have any thoughts on reducing numerical error, it would be appreciated!

silvandeleemput commented 2 years ago

I have done a few experiments on the numerical error of the reconstructed gradients w.r.t. the normal gradients and indeed I also found that the error increased exponentially with the number of invertible layers.
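As a rough illustration (not the exact experiments I ran, just a hypothetical stack of additive coupling blocks in plain PyTorch), you can already see the effect by pushing an input through the blocks and inverting it back: the float32 reconstruction error of the activations, which the reconstructed gradients build on, grows with the number of blocks.

import torch

torch.manual_seed(0)

def subnet(channels):
    # arbitrary small coupling subnetwork, purely for illustration
    return torch.nn.Sequential(torch.nn.Linear(channels, channels), torch.nn.Tanh())

half = 32  # channels per coupling half

for depth in (5, 25, 125):
    Fs = [subnet(half) for _ in range(depth)]
    Gs = [subnet(half) for _ in range(depth)]
    x1, x2 = torch.randn(8, half), torch.randn(8, half)
    y1, y2 = x1, x2
    with torch.no_grad():
        for F, G in zip(Fs, Gs):  # additive coupling forward: y1 = x1 + F(x2), y2 = x2 + G(y1)
            y1 = y1 + F(y2)
            y2 = y2 + G(y1)
        r1, r2 = y1, y2
        for F, G in zip(reversed(Fs), reversed(Gs)):  # exact algebraic inverse, block by block
            r2 = r2 - G(r1)
            r1 = r1 - F(r2)
    # in exact arithmetic (r1, r2) == (x1, x2); in float32 a small error remains
    error = torch.max((r1 - x1).abs().max(), (r2 - x2).abs().max()).item()
    print(depth, error)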

Also check figure 4 from the RevNet paper: https://arxiv.org/pdf/1707.04585.pdf. Note the section in the paper on the numerical error with a suggestion to combat it if necessary (with some additional overhead):

Numerical error. While Eqn. 8 reconstructs the activations exactly when done in exact arithmetic, practical float32 implementations may accumulate numerical error during backprop. We study the effect of numerical error in Section 5.2; while the error is noticeable in our experiments, it does not significantly affect final performance. We note that if numerical error becomes a significant issue, one could use fixed-point arithmetic on the x’s and y’s (but ordinary floating point to compute F and G), analogously to [19]. In principle, this would enable exact reconstruction while introducing little overhead, since the computation of the residual functions and their derivatives (which dominate the computational cost) would be unchanged.

jambo6 commented 2 years ago

Thanks!