Lightning-AI / lightning-thunder


thunder.examine.memory_calculation.get_alloc_memory overestimates the peak allocated memory on Colab #677

Open IvanYashchuk opened 3 days ago

IvanYashchuk commented 3 days ago

🐛 Bug

get_alloc_memory has the potential to be a great tool for estimating the effect of trace transformations on memory usage. For those who don't know about this function, it lives here: https://github.com/Lightning-AI/lightning-thunder/blob/8c5905fd1a93145e690791a7c7a3c3e10b16b32b/thunder/examine/memory_caculation.py#L120-L137

I've tried generating the execution trace with fake CUDA tensors so that the final trace can be analyzed without actual execution, even on low-memory GPUs like the ones in Colab:

from litgpt.config import Config
from litgpt import GPT
import thunder
import torch

from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()

config = Config.from_name("Llama-3-8B")

with fake_mode, torch.device("cuda"):
    model = GPT(config).to(torch.bfloat16)

with fake_mode:
    micro_batch_size = 1
    x = torch.randint(0, config.vocab_size, (micro_batch_size, model.max_seq_length), device="cuda")  # token ids in [0, vocab_size)
cmodel = thunder.jit(model)

cache_entry, inps, pro_to_epi = thunder.compile_data(cmodel).get_computation_and_inputs(x)

The estimated peak memory for this trace is, for some reason, different on Colab and locally:

import thunder.examine.memory_caculation
estimated_peak = thunder.examine.memory_caculation.get_alloc_memory(cache_entry.computation_traces[-1])[0]
print(f"Estimated peak memory usage in GiB: {estimated_peak / (1024 ** 3)}")
print(f"Estimated peak memory usage in GB: {estimated_peak / (1000 ** 3)}")

On Colab I see:

Estimated peak memory usage in GiB: 185.74466705322266
Estimated peak memory usage in GB: 199.4418176

and locally:

Estimated peak memory usage in GiB: 56.76426076889038
Estimated peak memory usage in GB: 60.950160896

On Colab, litgpt, nvFuser, and Thunder are installed with:

!pip install litgpt
!pip install nvfuser-cu121-torch23
!pip install git+https://github.com/Lightning-AI/lightning-thunder@main

cc @apaz-cli

IvanYashchuk commented 3 days ago

Oh, maybe on Colab with a T4 GPU Thunder decides not to use Flash Attention and decomposes the SDPA call...

IvanYashchuk commented 3 days ago

Right, there is no sdpafx_grad_forward_scaled_dot_product_efficient_attention in the trace, so Colab's estimate makes sense for the decomposition Thunder produces. What can we do to ask Thunder to produce the same trace as it would on an H100 when running on a T4, or with no GPU at all?
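
For reference, one quick way to check which attention implementation ended up in the final trace (a sketch reusing cache_entry from the snippet above; it assumes bound_symbols and sym.name are the right attributes to look at):

trace = cache_entry.computation_traces[-1]
# Collect every top-level bound symbol whose name mentions scaled_dot_product
sdpa_calls = [bsym.sym.name for bsym in trace.bound_symbols if "scaled_dot_product" in bsym.sym.name]
print(sdpa_calls)  # on Colab/T4 this does not include the sdpa executor symbol, confirming the decomposition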

lantiga commented 7 hours ago

That would be very cool. I guess one would need to override some behavior in the checkers.

Right now the claim made by checkers is twofold:

1. the executor has an implementation that correctly expresses the operation;
2. that implementation can actually run on the hardware currently available.

We probably need to decouple the two, so that only the first check is done, irrespective of the second. It could be a flag we pass or something else.
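
For illustration, the decoupling could look roughly like this; every name below is made up, not the current checker API:

def sdpa_checker(query, key, value, *, assume_hardware_supported=False):
    # Claim 1: the executor can express this call correctly (dtypes, shapes, strides, ...).
    if not inputs_are_supported(query, key, value):  # hypothetical helper
        return False
    # Claim 2: the GPU we are currently running on can execute the fused kernel.
    # This is the part to skip (or answer from a description of the target
    # hardware) when tracing on a different device than the one we target.
    if assume_hardware_supported:
        return True
    return current_device_supports_fused_sdpa(query.device)  # hypothetical helper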

t-vi commented 3 hours ago

To my mind, it would be feasible to include some form of "assume this hardware" in compile data and then divert queries for hardware properties to that, falling back to real hardware.

lantiga commented 2 hours ago

It would be great to have a dict, settable from the outside, that describes the hardware capabilities, so backends can use it in their checker functions.

In general, the more we avoid relying on the actual underlying hardware to reason about the computation, the better. This should be an explicit goal.
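
Roughly what I have in mind, combined with the fallback to the real hardware mentioned above (the dict name, its keys, and the helper are all hypothetical):

import torch

# Set from the outside to describe the target hardware; empty means "ask the real device".
assumed_hardware: dict = {}
# e.g. assumed_hardware = {"device_capability": (9, 0)}  # pretend we are on an H100

def device_capability(device_index: int = 0) -> tuple[int, int]:
    # Checkers would call this instead of querying torch.cuda directly.
    if "device_capability" in assumed_hardware:
        return assumed_hardware["device_capability"]
    return torch.cuda.get_device_capability(device_index)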

IvanYashchuk commented 1 minute ago

I agree that the checkers should be more controllable and that there should be a way to describe the target hardware. It could also be useful in a future export scenario. Let's create a separate issue to track this goal.

For this particular operation, it's even more complicated because we redirect to PyTorch to tell us which scaled_dot_product_attention backend to use: https://github.com/Lightning-AI/lightning-thunder/blob/5fc67dcba844554c8a2390ef8775594e61f18737/thunder/executors/sdpaex.py#L654-L661
However, it shouldn't be too hard to force the selection of a particular backend in the checker: https://github.com/Lightning-AI/lightning-thunder/blob/5fc67dcba844554c8a2390ef8775594e61f18737/thunder/executors/sdpaex.py#L683-L699
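
A sketch of what forcing the selection could look like in that checker, along the lines of the assumed-hardware idea above (the names and the setting are hypothetical, not the code at the linked lines):

# Hypothetical: the checker consults an externally set preference before
# redirecting to PyTorch's own backend selection.
forced_sdpa_backend = None  # e.g. "efficient" or "flash"

def choose_sdpa_backend(query, key, value):
    if forced_sdpa_backend is not None:
        return forced_sdpa_backend
    # placeholder for the existing redirect to PyTorch's selection logic
    return ask_pytorch_for_backend(query, key, value)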