Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

Recomputation of saved for backward tensors #871

Open IvanYashchuk opened 1 month ago

IvanYashchuk commented 1 month ago

🚀 Feature

In Thunder, we have one memory-saving optimization called rematerialization, but it's applicable only to fusion regions and does nothing for computation outside of them. Rematerialization is not only a memory-saving optimization but also a performance optimization: the general expectation is that the fewer global memory reads and writes a fused region needs to do, the better it performs, even though more computation is introduced into the fused region.

However, having just this memory-saving pass is not enough for models with large context windows or large parameter counts. The structure of the backward computation requires additional information about the intermediates of the forward computation; these are usually referred to as “activations” or, in PyTorch code, “saved for backward tensors”. The requirement to keep these tensors in GPU memory contributes to the peaks in memory consumption. Here’s an example of how the total size of saved-for-backward tensors scales with the number of layers for the Thunder-compiled LitGPT model with the “longchat-7b-16k” configuration: [plot: saved-tensor-usage-vs-n_layers-longchat-7b-16k] 67% of the total memory is occupied by saved-for-backward tensors.
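For reference, a rough way to measure how much autograd saves for backward in eager PyTorch is the public `torch.autograd.graph.saved_tensors_hooks` API. The sketch below is only illustrative and uses a toy model, not the LitGPT configuration from the plot:

```python
import torch
import torch.nn as nn

# Illustrative sketch: count the bytes autograd saves for backward.
# The model is a toy stand-in, not the "longchat-7b-16k" configuration.
saved_bytes = 0

def pack_hook(t):
    global saved_bytes
    saved_bytes += t.numel() * t.element_size()
    return t  # keep the tensor unchanged; we only want to count it

def unpack_hook(t):
    return t

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
x = torch.randn(8, 1024, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = model(x).sum()

print(f"saved for backward: {saved_bytes / 2**20:.1f} MiB")
```

(Tensors saved more than once are counted more than once, so this is an upper bound.)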

PyTorch implements activation checkpointing via a special torch.autograd.Function that saves the required inputs in detached form for recomputation in backward and performs the forward computation under a no_grad context. Backward uses the stashed inputs and the forward function to restore the missing tensors, which are then used to invoke the original backward function. Thunder doesn’t support tracing torch.autograd.Function.backward (https://github.com/Lightning-AI/lightning-thunder/issues/666); it also doesn’t support a no_grad context manager inside the traced function, and it doesn’t recognize PyTorch’s autograd graph manipulations like detach() and requires_grad_(True/False).
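For reference, the pattern described above, stripped of RNG-state handling and the other details that the real `torch.utils.checkpoint` implementation takes care of, looks roughly like this:

```python
import torch

class Checkpoint(torch.autograd.Function):
    # Minimal sketch of PyTorch-style activation checkpointing; the real
    # torch.utils.checkpoint also handles RNG state, kwargs, non-tensor
    # inputs, etc. Only the save/recompute structure is shown here.

    @staticmethod
    def forward(ctx, run_function, *inputs):
        ctx.run_function = run_function
        # Stash detached inputs instead of the intermediate activations.
        ctx.save_for_backward(*[t.detach() for t in inputs])
        with torch.no_grad():
            return run_function(*inputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = [t.detach().requires_grad_(True) for t in ctx.saved_tensors]
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)  # recompute the forward
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)
        # One gradient per forward argument: None for run_function itself.
        return (None, *[t.grad for t in inputs])
```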

PyTorch requires users to mark explicitly where saved-tensor recomputation should be applied. In Thunder, we can do this automatically.
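In stock PyTorch the “special marker” is `torch.utils.checkpoint.checkpoint`, which the user has to call explicitly around every region they want recomputed (the `block` below is just a placeholder module):

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU())
x = torch.randn(4, 512, requires_grad=True)

# Activations inside `block` are not saved; they are recomputed in backward.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```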

Thunder’s advantage is that we control both the forward and backward computations, so we can implement a trace transformation that recomputes the “saved for backward tensors”.
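To make the idea concrete, here is a purely conceptual sketch of such a transformation. It does not use Thunder’s real trace representation; `Op` and `recompute_saved_tensors` are made-up names for illustration only:

```python
from dataclasses import dataclass

@dataclass
class Op:
    # Conceptual stand-in for a trace symbol: one output computed from inputs.
    output: str
    inputs: tuple

def recompute_saved_tensors(forward_ops, forward_inputs, saved_for_backward, backward_ops):
    """Rebuild saved-for-backward tensors inside the backward trace.

    Instead of passing `saved_for_backward` from forward to backward, prepend
    to the backward the subset of forward ops needed to recompute them from
    the forward inputs (which the backward must still receive).
    """
    available = set(forward_inputs)
    needed = set(saved_for_backward) - available
    recompute = []
    # Walk the forward trace backwards, pulling in every op whose output is
    # needed, and propagating the requirement to that op's inputs.
    for op in reversed(forward_ops):
        if op.output in needed:
            recompute.append(op)
            needed.discard(op.output)
            needed.update(name for name in op.inputs if name not in available)
    recompute.reverse()
    # New backward: first recompute the saved tensors, then run the original
    # backward ops; only the forward inputs cross the forward/backward boundary.
    return recompute + list(backward_ops)
```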

There are two places where we could insert the logic for tensor recomputation:

  1. Right after the forward and backward traces are created and before the “transform for execution” phase. The advantage of this is that there’s an opportunity for fusing recomputed regions with the rest of the backward regions. However, the memory-saving effect is difficult to estimate because it’s the “pre-fusion” phase. In code: https://github.com/Lightning-AI/lightning-thunder/blob/7e1f9bd95c29b9642173088e4d5c86c84db434cc/thunder/executors/torch_autograd.py#L136
  2. After the “transform for execution” and “fusion rematerialization” phases. At this point we have a ready-for-execution Python function, so any memory-saving estimates will be more accurate, but fusion opportunities between the recomputed regions and the rest of the backward may be missed. In code: https://github.com/Lightning-AI/lightning-thunder/blob/7e1f9bd95c29b9642173088e4d5c86c84db434cc/thunder/executors/torch_autograd.py#L213

I think the first option is better and less intrusive in the current compilation pipeline.

kshitij12345 commented 1 month ago

Based on my understanding, activation checkpointing is useful when recomputing the output/activation is relatively cheap. But for compute-bound operations (matmuls), I think it may also be useful to have CPU offloading of these intermediate activations (if the CPU-GPU memory transfer time is less than the recompute time). I think the ideal approach will involve activation checkpointing taking care of operations where recompute is cheaper and CPU offloading taking care of operations where compute is expensive (and loading from memory is cheaper). Curious to know your thoughts :)
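For reference, eager PyTorch already exposes a simple form of offloading via `torch.autograd.graph.save_on_cpu`, which moves every saved tensor to (optionally pinned) host memory and copies it back on demand during backward. A usage sketch, assuming a CUDA device and a placeholder model:

```python
import torch
import torch.nn as nn

# Illustrative sketch: offload saved-for-backward tensors to host memory.
# pin_memory=True uses pinned host buffers, which makes the device-to-host
# and host-to-device copies cheaper.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).sum()

loss.backward()  # saved tensors are copied back from the CPU as needed
```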

IvanYashchuk commented 1 month ago

Yes, the ideal approach would use a mix of offloading and recomputation. To do this automatically in the future, we first need the tools for doing it manually. The initial work on activation recomputation will enable Thunder to work with larger models; yes, recomputing matmuls is expensive, but that's what needs to be done to save memory. Once we can run those models, we'll switch to improving performance by recomputing less.