Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

fsdp(jit(model)) + parameter sharing - don't duplicate allgather #602

Closed kshitij12345 closed 2 weeks ago

kshitij12345 commented 2 weeks ago

Ref: https://github.com/Lightning-AI/lightning-thunder/issues/257 and https://github.com/Lightning-AI/lightning-thunder/issues/257#issuecomment-2133070264

Currently, for shared parameters, we issue an AllGather (which creates a new tensor for its output) for each name that refers to the shared parameter, creating multiple copies of the unsharded tensor during execution. In this PR, we track the shared parameters and update the computation trace so that only one copy is gathered.

NOTE: This PR only handles the fsdp(jit(model)) path. Will have a look at jit(fsdp(model)) separately.
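
For reference, one way to detect shared parameters is to group parameter names by the underlying storage they alias. The helper below is a hypothetical sketch for illustration only (`find_shared_parameter_groups` is not part of this PR or of Thunder's API):

import torch

def find_shared_parameter_groups(module: torch.nn.Module) -> dict[int, list[str]]:
    # Group parameter names by the data pointer of their underlying storage;
    # names that end up in the same group alias the same (shared) parameter.
    groups: dict[int, list[str]] = {}
    for name, param in module.named_parameters(remove_duplicate=False):
        groups.setdefault(param.untyped_storage().data_ptr(), []).append(name)
    return {ptr: names for ptr, names in groups.items() if len(names) > 1}

# For the Model in the example below (after the weight-sharing assignment),
# this returns a single group containing 'fc1.weight' and 'fc2.weight'.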

Example:

import os
import torch
import torch.distributed as tdist
import thunder
import thunder.distributed

if __name__ == "__main__":
    tdist.init_process_group(backend="nccl")
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = torch.nn.Linear(16, 16, bias=False)
            self.fc2 = torch.nn.Linear(16, 16, bias=False)

        def forward(self, x):
            return self.fc1(x) + self.fc2(x)

    with device:
        model = Model()

    # shared parameter
    model.fc1.weight = model.fc2.weight

    model = thunder.jit(model, executors=["torch"])
    model = thunder.distributed.fsdp(model)

    logits = model(torch.randn(4, 16, device=device))

    if LOCAL_RANK == 0:
        # print(torch.cuda.max_memory_allocated())
        pro_trace = thunder.last_prologue_traces(model)[-1]
        with open("weight_sharing_trace_pro.py", 'w') as f:
            f.write(str(pro_trace))
        trace = thunder.last_traces(model)[-1]
        with open("weight_sharing_trace.py", 'w') as f:
            f.write(str(trace))
        bwd_trace = thunder.last_backward_traces(model)[-1]
        with open("weight_sharing_bwd_trace.py", 'w') as f:
            f.write(str(bwd_trace))

Computation Trace (Before PR)

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, t_fc1_weight, t_fc2_weight):
  # x: "cuda:0 f32[4, 16]"
  # t_fc1_weight: "cuda:0 f32[8, 16]"
  # t_fc2_weight: "cuda:0 f32[8, 16]"
  p0 = torch_all_gather_prim_impl(t_fc1_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True, None)  # p0: "FUTURE cuda:0 f32[16, 16]"
  p3 = torch_all_gather_prim_impl(t_fc2_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True, None)  # p3: "FUTURE cuda:0 f32[16, 16]"
  t1 = torch_wait_prim_impl(p0)  # t1: "cuda:0 f32[16, 16]"
  del p0
  t2 = torch.nn.functional.linear(x, t1, None)  # t2: "cuda:0 f32[4, 16]"
    # t2 = ltorch.linear(x, t1, None)  # t2: "cuda:0 f32[4, 16]"
      # t2 = prims.linear(x, t1, None)  # t2: "cuda:0 f32[4, 16]"
  del t1
  t4 = torch_wait_prim_impl(p3)  # t4: "cuda:0 f32[16, 16]"
  del p3
  t5 = torch.nn.functional.linear(x, t4, None)  # t5: "cuda:0 f32[4, 16]"
    # t5 = ltorch.linear(x, t4, None)  # t5: "cuda:0 f32[4, 16]"
      # t5 = prims.linear(x, t4, None)  # t5: "cuda:0 f32[4, 16]"
  del t4
  t6 = torch.add(t2, t5)  # t6: "cuda:0 f32[4, 16]"
    # t6 = ltorch.add(t2, t5, alpha=None)  # t6: "cuda:0 f32[4, 16]"
      # t6 = prims.add(t2, t5)  # t6: "cuda:0 f32[4, 16]"
  del t2, t5
  return {'output': t6, 'flat_args': [x, t_fc1_weight, t_fc2_weight], 'flat_output': (t6,)}, ((x,), ())

Computation Trace (After PR). NOTE: we don't change the signature of the trace; it still takes all the parameters as input, but the duplicated shared parameters are simply unused. This is fine, as they are backed by the same tensor and don't incur extra memory.

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(x, t_fc1_weight, t_fc2_weight):
  # x: "cuda:0 f32[4, 16]"
  # t_fc1_weight: "cuda:0 f32[8, 16]"
  # t_fc2_weight: "cuda:0 f32[8, 16]"
  p0 = torch_all_gather_prim_impl(t_fc1_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True, None)  # p0: "FUTURE cuda:0 f32[16, 16]"
  t1 = torch_wait_prim_impl(p0)  # t1: "cuda:0 f32[16, 16]"
  del p0
  t2 = torch.nn.functional.linear(x, t1, None)  # t2: "cuda:0 f32[4, 16]"
    # t2 = ltorch.linear(x, t1, None)  # t2: "cuda:0 f32[4, 16]"
      # t2 = prims.linear(x, t1, None)  # t2: "cuda:0 f32[4, 16]"
  t3 = torch.nn.functional.linear(x, t1, None)  # t3: "cuda:0 f32[4, 16]"
    # t3 = ltorch.linear(x, t1, None)  # t3: "cuda:0 f32[4, 16]"
      # t3 = prims.linear(x, t1, None)  # t3: "cuda:0 f32[4, 16]"
  del t1
  t4 = torch.add(t2, t3)  # t4: "cuda:0 f32[4, 16]"
    # t4 = ltorch.add(t2, t3, alpha=None)  # t4: "cuda:0 f32[4, 16]"
      # t4 = prims.add(t2, t3)  # t4: "cuda:0 f32[4, 16]"
  del t2, t3
  return {'output': t4, 'flat_args': [x, t_fc1_weight, t_fc2_weight], 'flat_output': (t4,)}, ((x,), ())
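
As a quick sanity check, one can count the all-gather calls that remain in the final computation trace: with the single shared-parameter group above, the deduplicated trace should contain one torch_all_gather_prim_impl call instead of two. A hypothetical snippet, assuming the repro script above has already run:

if LOCAL_RANK == 0:
    trace_str = str(thunder.last_traces(model)[-1])
    num_allgathers = trace_str.count("torch_all_gather_prim_impl")
    print(f"all_gather calls in computation trace: {num_allgathers}")  # 2 before this PR, 1 after
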
t-vi commented 2 weeks ago

NOTE: This PR only handles the fsdp(jit(model)) path. Will have a look at jit(fsdp(model)) separately.

Thank you. Or we could try to get rid of jit(fsdp(model)) soonish. :)