Before this PR, we would run out of memory quantizing the weights of Llama 3 70B on two H100 80GB GPUs.
This happened because, while quantizing, we kept references to the original Linear modules, so PyTorch could not free their storage. Now, as we quantize each module, we explicitly clone the weights and biases and then delete all of the original Parameters. This massively reduces peak memory usage and essentially removes it as a concern beyond the initial unquantized checkpoint load.
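A minimal sketch of the clone-then-delete pattern, not the actual implementation: `quantize_module_` and the int8 per-tensor scheme here are illustrative stand-ins for the real quantization code. The key point is that the original Parameters are deleted as soon as their data is cloned, so their storage can be freed before the next module is processed.

```python
import torch
import torch.nn as nn

def quantize_module_(module: nn.Linear) -> None:
    """Quantize one Linear module in place without keeping the
    original Parameters alive. (Int8 per-tensor quantization is a
    stand-in for the real scheme.)"""
    # Clone so the quantized buffers do not share storage with the
    # original Parameters.
    weight = module.weight.detach().clone()
    bias = module.bias.detach().clone() if module.bias is not None else None

    # Delete the original Parameters immediately so PyTorch can free
    # their storage now, instead of holding them for the whole pass.
    del module.weight
    if bias is not None:
        del module.bias

    # Store the quantized weight and scale as buffers.
    scale = weight.abs().amax() / 127.0
    module.register_buffer("qweight", torch.round(weight / scale).to(torch.int8))
    module.register_buffer("weight_scale", scale)
    if bias is not None:
        module.register_buffer("bias", bias)
```

After `quantize_module_` runs, the module holds only buffers, and iterating its `parameters()` yields nothing, which is what lets peak memory stay bounded while walking a large model module by module.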