First of all, thanks for the great work!
Problem Description
I observed GPU memory leaks when running OmniQuant's training loop. The numbers below come from running LWC on the Llama-2-13b model with 128 calibration samples, a batch size of 1, and a sequence length of 2048.
GPU memory starts at ~15 gigs at the beginning of the first epoch of the first decoder layer. Memory consumption gradually increases as the calibration samples are processed. At the end of the first epoch, GPU memory peaks at ~20 gigs and stays there until the end of the training loop.
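For anyone who wants to reproduce the measurement, something like the snippet below can be dropped into the calibration loop; the helper is purely illustrative and not part of OmniQuant's code:

```python
import torch

def log_gpu_memory(tag: str) -> None:
    """Print current and peak allocated GPU memory in GiB (illustrative helper only)."""
    current = torch.cuda.memory_allocated() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[{tag}] allocated={current:.1f} GiB, peak={peak:.1f} GiB")
```

Calling it once per calibration sample is enough to see the gradual increase described above.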
Fix
The main difference is on line 227, where I called .detach().cpu() on the loss before appending it to the list. This should allow the backward computation graph to be cleaned up by the garbage collector. Nevertheless, I haven't figured out why it doesn't work when I drop the .cpu().
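For clarity, here is a minimal, self-contained sketch of the pattern; the model, optimizer, and loop are placeholders rather than OmniQuant's actual code, and only the final append line reflects the change:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 1).to(device)      # placeholder for the trainable quantization parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

loss_list = []
for _ in range(8):                              # stands in for the calibration samples
    x = torch.randn(1, 16, device=device)
    loss = (model(x) ** 2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Before the fix, appending the live tensor (loss_list.append(loss)) kept the
    # whole autograd graph and its activations alive for every sample, so GPU
    # memory grew each iteration. detach() drops the graph and .cpu() moves the
    # scalar off the GPU before it is stored.
    loss_list.append(loss.detach().cpu())
```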
I also decorate the gradient norm calculation function with @torch.no_grad() to further reduce the memory footprint. The reduction is rather minor, though.
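A sketch of what that decorator looks like on a generic gradient-norm helper (the function below is illustrative, not the actual one in the repo):

```python
import torch

@torch.no_grad()  # no autograd graph is needed just to read gradient magnitudes
def grad_norm(parameters) -> float:
    """Return the global L2 norm of all gradients (illustrative helper only)."""
    total = 0.0
    for p in parameters:
        if p.grad is not None:
            total += float(p.grad.detach().norm() ** 2)
    return total ** 0.5
```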
After the fix, running the same experiment setup yields a maximum GPU memory usage of ~15 gigs, which is almost the same as where it starts.
Again, thanks for the work!