Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

fsdp(jit(...)) transform can use more memory compared to jit(fsdp(...)) #478

Open kshitij12345 opened 1 month ago

kshitij12345 commented 1 month ago

Because fsdp(jit(...)) holds on to the original parameters as well as the sharded parameters, it can lead to higher memory usage. I think a work-around can be to initialize the original model on the meta device. But if using meta is the only correct way, then we should add a warning if the user does otherwise.

import os
import torch
import torch.distributed as tdist
import thunder
import thunder.distributed

if __name__ == "__main__":
    tdist.init_process_group(backend="nccl")
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.p1 = torch.nn.Parameter(torch.ones(1024, 1024))

        def forward(self, x):
            return self.p1 + x

    with device:
        model = Model()
        input_t = torch.randn(1)

    # jit(fsdp(...))
    # Memory Allocated - 6291968
    # model = thunder.distributed.fsdp(model)
    # model = thunder.jit(model, executors=["torch"])

    # fsdp(jit(...))
    # Memory Allocated - 10486272
    model = thunder.jit(model, executors=["torch"])
    model = thunder.distributed.fsdp(model)

    _ = model(input_t)

    if LOCAL_RANK == 0:
        print(torch.cuda.memory_allocated())
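
As a quick sanity check on the two numbers in the comments above, the gap between them can be compared with the size of one full copy of `p1`. This is only a sketch of the arithmetic; it assumes the default float32 dtype and does not attempt to account for every allocation in the run.

```python
# Sizes taken from the repro comments above.
BYTES_PER_FLOAT32 = 4  # assumption: default dtype is float32

# self.p1 is a 1024 x 1024 tensor
full_param_bytes = 1024 * 1024 * BYTES_PER_FLOAT32

allocated_jit_of_fsdp = 6291968    # reported for jit(fsdp(...))
allocated_fsdp_of_jit = 10486272   # reported for fsdp(jit(...))

extra_bytes = allocated_fsdp_of_jit - allocated_jit_of_fsdp

print(full_param_bytes)                 # 4194304
print(extra_bytes)                      # 4194304
print(extra_bytes == full_param_bytes)  # True
```

The extra allocation matches one unsharded float32 copy of `p1`, which is consistent with the original parameter being kept alive alongside the shard.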

cc: @t-vi

cc @carmocca @awaelchli @crcrpar
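
For reference, the meta-device work-around mentioned above might start like the sketch below. It uses only plain PyTorch (2.0 or later, for the `torch.device` context manager) and reuses the `Model` class from the repro. Whether `thunder.jit` followed by `thunder.distributed.fsdp` then materializes the parameters as intended is not shown here and is an assumption to verify.

```python
import torch


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.p1 = torch.nn.Parameter(torch.ones(1024, 1024))

    def forward(self, x):
        return self.p1 + x


# Constructing under the meta device records shape and dtype only;
# no real storage is allocated for the parameter.
with torch.device("meta"):
    meta_model = Model()

print(meta_model.p1.is_meta)  # True
print(meta_model.p1.shape)    # torch.Size([1024, 1024])
```

Because the parameter has no backing storage, there is no full-size original copy to be held on to while the sharded copy is created.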

t-vi commented 1 month ago

Indeed, and this is tricky:

mruberry commented 4 weeks ago

triage review:

crcrpar commented 2 weeks ago

@mruberry

  • we should be careful that retracing has the information needed (possibly observing original values) to work as expected

could you elaborate on what it means?

t-vi commented 2 weeks ago

Two parts:

I would like the solution to #483 / #564 to enable moving materialization out of the sharding, doing it before we run the model, and propagating data through what we have for #483 (which needs to deal with "has been moved to meta", too).
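
To make "materialization" concrete, here is a minimal plain-PyTorch sketch of turning a meta-device module into a real one before running it. It is not Thunder's implementation: `torch.nn.Linear` is a stand-in module and the CPU target is chosen only so the snippet runs anywhere.

```python
import torch

# Build the module on the meta device: shapes only, no storage.
with torch.device("meta"):
    layer = torch.nn.Linear(4, 4)

# to_empty allocates uninitialized storage on the target device.
layer = layer.to_empty(device="cpu")

# Values are undefined after to_empty, so re-run the module's own init.
layer.reset_parameters()

print(layer.weight.is_meta)      # False
print(layer.weight.device.type)  # cpu
```

Doing this step explicitly, outside of sharding, is one way to read the proposal above; how it would interact with the tracking needed for #483 is left open in the thread.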