medivh-xp opened 4 days ago
cc: @XilunWu
PR #592 enables context parallelism (CP) in torchtitan. You can set context_parallel_degree (for example, 8 for CP8) in the TOML config file. See the PR description for details.
CP8 is enough for a 128K sequence length on H100 and A100. If you still hit OOM, you can switch activation checkpointing from selective to "full" to further reduce peak memory usage.
cc: @lessw2020
@XilunWu Thank you for your reply! I noticed that PR #467 reduces activation memory through activation offloading. If a balance can be struck among computation, memory, and H2D bandwidth, it seems that full AC might not be necessary (I'm not sure my understanding is correct; full-AC recomputation significantly reduces MFU). So how should I choose between full AC and activation offloading? It seems that activation offloading could theoretically achieve a higher MFU?
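A back-of-envelope way to frame this trade-off, assuming the usual dense-matmul approximation of about 2N forward and 4N backward FLOPs per token (N = parameter count, attention ignored): full AC adds one extra forward, so MFU falls to about 6/8 of what the same hardware utilization would give without AC, while offloading only comes for free if each layer's saved activations can be copied within that layer's compute time. The helper names and the numbers in the usage lines are hypothetical placeholders, not measurements.

```python
def full_ac_flops_ratio() -> float:
    """Training FLOPs with full AC relative to no AC, per token.

    Dense approximation: forward ~ 2N, backward ~ 4N (N = parameter count).
    Full AC re-runs the forward once during backward: 6N -> 8N.
    """
    fwd, bwd = 2, 4
    return (fwd + bwd + fwd) / (fwd + bwd)


def offload_fits(act_bytes_per_layer: float, link_bytes_per_s: float,
                 layer_compute_s: float) -> bool:
    """Whether one layer's saved activations can cross the host link within the
    compute time available in that direction (forward for the device-to-host
    offload, backward for the host-to-device reload), i.e. whether the copy could
    be hidden behind compute. link_bytes_per_s should be the bandwidth left after
    other traffic sharing the same PCIe link (e.g. inter-node collectives).
    """
    return act_bytes_per_layer / link_bytes_per_s <= layer_compute_s


ratio = full_ac_flops_ratio()
print(f"full AC: {ratio:.3f}x FLOPs -> MFU ~{1 / ratio:.0%} of the no-AC value "
      "at equal hardware utilization")
# Hypothetical: 1.5 GB saved per layer, 20 GB/s usable, 50 ms of compute per layer.
print(offload_fits(1.5e9, 20e9, 0.050))  # False: 75 ms of copying vs 50 ms of compute
```

Under these assumptions offloading can beat full AC on MFU only when the usable link bandwidth keeps up with the rate at which activations are produced.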
@awgu can you share a bit more on the status of the activation offloading PR? For example, is it ready to use, and how does its performance compare with full AC on Llama models?
The PR is meant as a way to add activation offloading to your model with intrusive changes. The main concern is that on current-generation NVIDIA GPUs, offloading may contend with inter-node collectives for PCIe bandwidth.
If you apply full activation checkpointing to each transformer block and then also offload each transformer block's input, no extra GPU memory accumulates per transformer block, which can help unblock long-sequence use cases.
That probably needs some extra work on the PR, though.
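A minimal PyTorch sketch of the combination described above (full AC on each block plus offloading the block input), written as a reentrant-style custom autograd.Function. This is not the code from PR #467: CheckpointOffloadInput, Block and toy_grads are hypothetical names, the copies are synchronous (a practical version would use pinned host buffers and a side stream to overlap transfers with compute), and the block is assumed to be deterministic, since unlike torch.utils.checkpoint this sketch does not restore RNG state for dropout.

```python
import torch
import torch.nn as nn


class CheckpointOffloadInput(torch.autograd.Function):
    """Full AC for one block, keeping only a host copy of the block input."""

    @staticmethod
    def forward(ctx, x, block):
        ctx.block = block
        ctx.device = x.device
        # Offload: synchronous device-to-host copy of the block input.
        ctx.x_host = x.detach().to("cpu")
        # No graph is recorded, so no intermediate activations stay on the device.
        with torch.no_grad():
            return block(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Reload: host-to-device copy, then recompute the block's activations.
        x = ctx.x_host.to(ctx.device).requires_grad_(True)
        with torch.enable_grad():
            out = ctx.block(x)
        # Accumulates into the block parameters' .grad and into x.grad.
        torch.autograd.backward(out, grad_out)
        return x.grad, None  # no gradient for the `block` argument


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block: a residual MLP."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


def toy_grads(offload: bool, n_blocks: int = 4, dim: int = 16):
    """Input and parameter gradients of a toy model, with or without AC + offload."""
    torch.manual_seed(0)
    blocks = nn.ModuleList([Block(dim) for _ in range(n_blocks)]).double()
    x = torch.randn(2, 8, dim, dtype=torch.float64, requires_grad=True)
    h = x
    for blk in blocks:
        h = CheckpointOffloadInput.apply(h, blk) if offload else blk(h)
    h.sum().backward()
    return [x.grad] + [p.grad for p in blocks.parameters()]


# The checkpointed + offloaded run should reproduce the plain gradients.
print(all(torch.allclose(a, b) for a, b in zip(toy_grads(False), toy_grads(True))))
```

Between forward and backward, each block then holds only a host-side copy of its input on the GPU's behalf, which is the "no extra GPU memory per block" property; the price is one block recompute plus one round trip over PCIe per block.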
At a 128K sequence length, activation memory grows significantly. CP8 + TP8 seems necessary (together they reduce activation memory almost linearly), but there is still as much as 50 GB of activation memory. Recomputing the MLP activations reduces that by about 9 GB, while recomputing the attention layer or the MLP up-projection linear seems rather costly. I noticed that the paper at https://arxiv.org/pdf/2410.06511 mentions applying full checkpointing to address the activation memory issue, which seems to significantly increase execution time due to recomputation. Does torchtitan plan to offload activations and reload them during the backward pass to reduce activation memory?
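For reference, the mechanism asked about here (offload saved activations during forward, reload them during backward) can be expressed in plain PyTorch with saved-tensor hooks; PyTorch also ships a ready-made variant, torch.autograd.graph.save_on_cpu. The sketch below is not torchtitan code and says nothing about its plans. OffloadSavedTensors and toy_grads are hypothetical names, copies are synchronous, and every saved tensor (including parameters saved for backward) is offloaded, whereas a practical version would likely skip parameters, use pinned memory, and overlap copies with compute.

```python
import contextlib

import torch
import torch.nn as nn
from torch.autograd.graph import saved_tensors_hooks


class OffloadSavedTensors(saved_tensors_hooks):
    """Move each tensor autograd saves for backward to host memory during forward,
    copy it back to its original device when backward unpacks it, and count bytes."""

    def __init__(self):
        self.offloaded_bytes = 0

        def pack(t):
            self.offloaded_bytes += t.numel() * t.element_size()
            return t.device, t.to("cpu")  # device-to-host copy (synchronous here)

        def unpack(packed):
            device, t_host = packed
            return t_host.to(device)  # host-to-device copy during backward

        super().__init__(pack, unpack)


def toy_grads(offload: bool):
    """Parameter gradients of a toy MLP, with or without offloading saved tensors."""
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32)).double()
    x = torch.randn(4, 32, dtype=torch.float64)
    hooks = OffloadSavedTensors()
    with hooks if offload else contextlib.nullcontext():
        loss = model(x).pow(2).sum()  # saved tensors are packed while this runs
    loss.backward()  # ...and unpacked (reloaded) here
    return [p.grad for p in model.parameters()], hooks.offloaded_bytes


ref, _ = toy_grads(offload=False)
got, nbytes = toy_grads(offload=True)
print(nbytes, all(torch.allclose(a, b) for a, b in zip(ref, got)))
```

Unlike the checkpoint-plus-input-offload approach, this moves all saved activations, so it avoids recomputation entirely but puts the full activation volume on the PCIe link in both directions.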