IvanYashchuk opened 1 month ago
Based on my understanding, activation checkpointing is useful when recomputing the output/activation is relatively cheap. But for compute-bound operations (matmuls), it may also be useful to have CPU offloading of these intermediate activations (when the CPU-GPU memory transfer time is less than the recomputation time). I think the ideal approach would involve activation checkpointing handling operations where recomputation is cheaper, and CPU offloading handling operations where recomputation is expensive (and loading from memory is cheaper). Curious to know your thoughts :)
Yes, the ideal approach would use a mix of offloading and recomputation. To do this automatically in the future, we first need tools for doing it manually. The initial work on activation recomputation will enable Thunder to work with larger models; recomputing matmuls is expensive, but that's what needs to be done to save memory. Once we can run those models, we'll switch to improving performance by recomputing less.
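For reference, the offloading half of this mix can already be sketched manually in plain PyTorch with `torch.autograd.graph.saved_tensors_hooks` (PyTorch also ships `torch.autograd.graph.save_on_cpu`, which does essentially this). A minimal sketch, independent of Thunder:

```python
import torch

# Sketch of CPU offloading for saved-for-backward tensors:
# pack_to_cpu runs when autograd stashes a tensor for backward;
# unpack_from_cpu runs when backward needs it again.

def pack_to_cpu(t):
    # Remember the original device so unpack can restore it.
    return (t.device, t.to("cpu"))

def unpack_from_cpu(packed):
    device, t = packed
    return t.to(device)

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x * x).sum()  # x is saved for backward and offloaded by the hook
y.backward()
```

On a CUDA device the `.to("cpu")`/`.to(device)` pair is where the transfer-time-vs-recompute-time trade-off discussed above would actually be paid.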
🚀 Feature
In Thunder, we have one memory-saving optimization called rematerialization, but it applies only to fusion regions and does nothing for computation outside them. Rematerialization is not only a memory-saving optimization but also a performance optimization: the general expectation is that the fewer global memory reads and writes a fused region performs, the better its performance, even if more computation is introduced into the region. However, having just this memory-saving pass is not enough for models with large context windows or large parameter counts. The backward computation requires information about intermediates of the forward computation, usually referred to as "activations" or, in PyTorch code, "saved for backward tensors". The requirement to store these tensors in GPU memory contributes to large peaks in memory consumption. Here's an example of how the total size of saved-for-backward tensors scales with the number of layers for the Thunder-compiled LitGPT model with the "longchat-7b-16k" configuration: 67% of the total memory is occupied by saved-for-backward tensors.
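One rough way to observe this scaling yourself is to count the bytes of every tensor autograd stashes, again via `saved_tensors_hooks`. The toy MLP below is illustrative only; it is not the LitGPT "longchat-7b-16k" configuration from the example above.

```python
import torch
import torch.nn as nn

def saved_bytes(model, x):
    """Total bytes of tensors saved for backward during one step."""
    total = 0
    def count(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # save the tensor unchanged
    with torch.autograd.graph.saved_tensors_hooks(count, lambda t: t):
        model(x).sum().backward()
    return total

# Saved-for-backward memory grows with depth.
for n_layers in (1, 2, 4, 8):
    model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(n_layers)])
    print(n_layers, saved_bytes(model, torch.randn(8, 64)))
```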
PyTorch implements activation checkpointing via a special `torch.autograd.Function` that saves the required inputs in detached form for recomputation in backward and performs the forward computation under a `no_grad` context. Backward uses the stashed inputs and the forward function to restore the missing tensors, which are then used to invoke the original backward function. Thunder doesn't support tracing `torch.autograd.Function.backward` (https://github.com/Lightning-AI/lightning-thunder/issues/666); it also doesn't support a `no_grad` context manager inside the traced function, and it doesn't recognize PyTorch's autograd graph manipulations like `detach()` and `requires_grad_(True/False)`. PyTorch requires users to insert special markers to apply saved-tensor recomputation. In Thunder, we can do this automatically.
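To make the mechanism concrete, here is a minimal sketch of what such a `torch.autograd.Function` looks like: store only the detached input, run the forward under `no_grad`, and recompute the activation during backward. `expensive_fn` is an illustrative stand-in, not Thunder or PyTorch code (PyTorch's production version is `torch.utils.checkpoint.checkpoint`).

```python
import torch

def expensive_fn(x):
    return torch.sin(x) * x  # stand-in for a costly forward computation

class Checkpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save only the detached input, not the intermediates.
        ctx.save_for_backward(x.detach())
        with torch.no_grad():
            return expensive_fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            y = expensive_fn(x)  # recompute instead of storing the output
        return torch.autograd.grad(y, x, grad_out)[0]

x = torch.randn(5, requires_grad=True)
Checkpoint.apply(x).sum().backward()
```

Note that every piece Thunder would need to trace here is on the unsupported list above: the `backward` static method, the `no_grad`/`enable_grad` contexts, `detach()`, and `requires_grad_(True)`.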
Thunder's advantage is that we control both the forward and backward computations, so we can implement a trace transformation that recomputes the "saved for backward" tensors.
There are two places where we could insert the logic for tensor recomputation:
I think the first option is better and less intrusive in the current compilation pipeline.
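Whichever insertion point is chosen, the transformation itself can be sketched schematically. This is NOT Thunder's real trace API, just an illustration of the idea: drop intermediates from the forward trace's saved-for-backward set and prepend the symbols that produce them to the backward trace, so backward rebuilds them from the forward inputs instead of reading them from GPU memory.

```python
from dataclasses import dataclass, field

@dataclass
class Trace:
    symbols: list  # each symbol: (output_name, op_name, input_names)
    saved_for_backward: list = field(default_factory=list)

def recompute_saved(fwd, bwd):
    producers = {out: (out, op, ins) for out, op, ins in fwd.symbols}
    # Prepend the producer of each saved tensor to the backward trace.
    # (A real pass would take the transitive closure of producers and
    # deduplicate against symbols already present in the backward trace.)
    prologue = [producers[t] for t in fwd.saved_for_backward if t in producers]
    return Trace(fwd.symbols), Trace(prologue + bwd.symbols)

# Toy example: forward computes t0 = x * x and saves t0 for backward;
# after the transformation, backward recomputes t0 from x instead.
fwd = Trace([("t0", "mul", ("x", "x"))], saved_for_backward=["t0"])
bwd = Trace([("grad_x", "mul_grad", ("grad_out", "t0"))])
new_fwd, new_bwd = recompute_saved(fwd, bwd)
```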