ucbrise / actnn

ActNN: Reducing Training Memory Footprint via 2-Bit Activation Compressed Training
MIT License

Cannot save memory during the FP #35

Closed: tonyzhao-jt closed this issue 2 years ago

tonyzhao-jt commented 2 years ago

Hi, part of my work builds on this project, and I am trying to quantize the activations during the forward pass. However, I noticed that although I can replace the tensors saved in ctx with my new compressed activation, the overall CUDA memory usage does not decrease; it even increases. It seems the original fp32 activation is never freed and is still counted in CUDA memory usage. I would like to understand why this happens and how to fix it.

For reference, here is roughly what I did:

def forward(self, input):
    # remaining arguments elided in the original post
    return qconv2d.apply(input, self.weight, ...)

where, inside qconv2d, I did:

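# quantize the fp32 activation to int8 (quantize_int8 is my own helper)
input_int8, scale_inp = quantize_int8(input)
...
# keep only the compressed activation (plus scale) for the backward pass
ctx.save_for_backward(input_int8, scale_inp, weight, bias)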

However, judging by torch.cuda.memory_allocated(0), both input and input_int8 seem to be kept in memory during the forward pass. Is that expected?

Hoping for your reply.
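For readers trying to reproduce the pattern above, here is a minimal, self-contained sketch of a conv function that keeps only an int8 copy of its input for backward. It assumes per-tensor symmetric int8 quantization; `QConv2d`, `quantize_int8`, and `dequantize_int8` are hypothetical stand-ins written for illustration, not the issue author's code and not ActNN's API. The forward output is computed in fp32, and backward uses `torch.nn.grad.conv2d_input` / `torch.nn.grad.conv2d_weight` on the dequantized activation, so the weight gradient is only approximate.

```python
import torch
import torch.nn.functional as F


def quantize_int8(x: torch.Tensor):
    """Hypothetical per-tensor symmetric int8 quantizer (illustration only)."""
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale


def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale


class QConv2d(torch.autograd.Function):
    """Hypothetical conv2d that saves an int8 activation instead of the fp32 one."""

    @staticmethod
    def forward(ctx, input, weight, bias, stride=1, padding=0):
        out = F.conv2d(input, weight, bias, stride, padding)
        input_int8, scale_inp = quantize_int8(input)
        # Only the compressed activation is saved; fp32 `input` is not referenced by ctx.
        ctx.save_for_backward(input_int8, scale_inp, weight, bias)
        ctx.stride, ctx.padding = stride, padding
        return out

    @staticmethod
    def backward(ctx, grad_out):
        input_int8, scale_inp, weight, bias = ctx.saved_tensors
        input_approx = dequantize_int8(input_int8, scale_inp)
        # grad w.r.t. the input only needs the input's shape, not its values
        grad_input = torch.nn.grad.conv2d_input(
            input_approx.shape, weight, grad_out, ctx.stride, ctx.padding)
        # grad w.r.t. the weight uses the dequantized (approximate) activation
        grad_weight = torch.nn.grad.conv2d_weight(
            input_approx, weight.shape, grad_out, ctx.stride, ctx.padding)
        grad_bias = grad_out.sum(dim=(0, 2, 3)) if bias is not None else None
        # stride and padding are plain ints, so they get no gradient
        return grad_input, grad_weight, grad_bias, None, None
```

Note that this only controls what the conv itself keeps; any other op in the graph that saves the same fp32 tensor will still keep it alive.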

merrymercy commented 2 years ago
  1. Try to reproduce our results and check the difference between your code and our code
  2. Make sure there is no other reference to your fp32 activation tensors, so they can be freed by the garbage collector.
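One way to act on the second suggestion is to list every tensor autograd keeps for backward, using `torch.autograd.graph.saved_tensors_hooks` (available in PyTorch 1.10 and later). The sketch below assumes a hypothetical `Int8SavingMatmul` op that saves only an int8 copy of its input, and a hypothetical helper `saved_tensor_report`; neither comes from the thread or from ActNN. When a stock `torch.relu` runs first, ReLU saves its own fp32 output, which is exactly the quantized op's input, so the fp32 activation stays referenced by the graph next to the int8 copy.

```python
import torch


class Int8SavingMatmul(torch.autograd.Function):
    """Hypothetical y = x @ w that keeps only an int8 copy of x for backward."""

    @staticmethod
    def forward(ctx, x, w):
        scale = x.abs().max().clamp(min=1e-8) / 127.0
        x_int8 = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
        ctx.save_for_backward(x_int8, scale, w)
        return x @ w

    @staticmethod
    def backward(ctx, grad_out):
        x_int8, scale, w = ctx.saved_tensors
        x_approx = x_int8.to(grad_out.dtype) * scale
        return grad_out @ w.t(), x_approx.t() @ grad_out


def saved_tensor_report(fn):
    """Run fn() and record (dtype, shape) of every tensor autograd saves."""
    saved = []

    def pack(t):
        saved.append((t.dtype, tuple(t.shape)))
        return t  # keep the tensor unchanged; we only observe it

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = fn()
    return out, saved


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 3, requires_grad=True)
out, with_relu = saved_tensor_report(lambda: Int8SavingMatmul.apply(torch.relu(x), w))
_, without_relu = saved_tensor_report(lambda: Int8SavingMatmul.apply(x, w))
print(with_relu)     # an fp32 (4, 8) entry from ReLU appears alongside the int8 copy
print(without_relu)  # only the int8 copy, the scalar scale, and w
out.sum().backward()
```

The same report, run on a real model, shows which layers still hold fp32 activations after the quantized layers have been swapped in.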
tonyzhao-jt commented 2 years ago

OK, so far I suspect there are some problems in my own method of converting an FP model to its quantized version. I will close the issue once I find a way out.

tonyzhao-jt commented 2 years ago

FYI, I finally solved the problem by replacing all the non-GEMM functions (ReLU, BatchNorm, MaxPool, etc.) with the versions provided by ActNN. There is a weird phenomenon: when using the models provided by torchvision, PyTorch seems to optimize its fp32 memory allocation in these ops during forward propagation. :) Again, I greatly appreciate the help you provided!
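A plausible explanation for why swapping the non-GEMM ops helped, offered as an assumption rather than a confirmed diagnosis: stock `torch.relu` saves its fp32 output for backward, and BatchNorm and MaxPool typically save their fp32 inputs, so the activation feeding a quantized conv stays alive regardless of what the conv itself saves. The sketch below illustrates the idea with a hypothetical `MaskReLU` that saves only a boolean mask (1 byte per element in PyTorch) and a hypothetical `saved_bytes` helper that totals what autograd saves. It is not ActNN's implementation, which compresses activations further.

```python
import torch


class MaskReLU(torch.autograd.Function):
    """Hypothetical ReLU that keeps a bool mask, not its fp32 output, for backward."""

    @staticmethod
    def forward(ctx, x):
        mask = x > 0
        ctx.save_for_backward(mask)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        # the gradient passes through only where the input was positive
        return grad_out * mask


def saved_bytes(fn, x):
    """Run fn(x) and return (output, total bytes of tensors autograd saved)."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        out = fn(x)
    return out, total


x = torch.randn(1024, requires_grad=True)
_, fp32_bytes = saved_bytes(torch.relu, x)     # ReLU keeps its fp32 output
_, mask_bytes = saved_bytes(MaskReLU.apply, x)  # only a 1-byte-per-element mask
print(fp32_bytes, mask_bytes)
```

Both versions produce the same gradient; the difference is only in what stays referenced between the forward and backward passes.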