Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
fsdp(jit(model)) + parameter sharing - don't duplicate allgather #602
Currently, for shared parameters, we AllGather (which creates a new tensor for its output) once for each name of the shared parameter, creating multiple copies during execution. In this PR, we track the shared parameters and update the computation trace so that only one copy is produced.
NOTE: This PR only handles the fsdp(jit(model)) path. The jit(fsdp(model)) path will be looked at separately.
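A rough sketch of the tracking idea (illustration only, not Thunder's actual implementation; the helper name find_shared_parameters is hypothetical): shared parameters can be detected by walking named_parameters(remove_duplicate=False) and mapping every duplicate name to the first (canonical) name that refers to the same tensor, so the trace rewrite keeps an AllGather only for the canonical name.

import torch

def find_shared_parameters(module: torch.nn.Module) -> dict[str, str]:
    """Map each duplicate parameter name to the canonical name of the tensor it aliases."""
    canonical: dict[int, str] = {}  # id(tensor) -> first name seen for that tensor
    remap: dict[str, str] = {}      # duplicate name -> canonical name
    for name, param in module.named_parameters(remove_duplicate=False):
        if id(param) in canonical:
            remap[name] = canonical[id(param)]
        else:
            canonical[id(param)] = name
    return remap

For the Model in the reproduction script below, this yields {'fc2.weight': 'fc1.weight'}, so only fc1.weight needs to be gathered.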
E.g. (reproduction script):
import os

import torch
import torch.distributed as tdist

import thunder
import thunder.distributed

if __name__ == "__main__":
    tdist.init_process_group(backend="nccl")
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", LOCAL_RANK)
    torch.set_default_device(device)

    class Model(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = torch.nn.Linear(16, 16, bias=False)
            self.fc2 = torch.nn.Linear(16, 16, bias=False)

        def forward(self, x):
            return self.fc1(x) + self.fc2(x)

    with device:
        model = Model()

    # shared parameter
    model.fc1.weight = model.fc2.weight

    # fsdp(jit(model)) path: jit first, then shard with FSDP
    model = thunder.jit(model, executors=["torch"])
    model = thunder.distributed.fsdp(model)

    logits = model(torch.randn(4, 16, device=device))

    if LOCAL_RANK == 0:
        # print(torch.cuda.max_memory_allocated())
        # Dump the prologue, forward, and backward traces for inspection.
        pro_trace = thunder.last_prologue_traces(model)[-1]
        with open("weight_sharing_trace_pro.py", 'w') as f:
            f.write(str(pro_trace))

        trace = thunder.last_traces(model)[-1]
        with open("weight_sharing_trace.py", 'w') as f:
            f.write(str(trace))

        bwd_trace = thunder.last_backward_traces(model)[-1]
        with open("weight_sharing_bwd_trace.py", 'w') as f:
            f.write(str(bwd_trace))
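To reproduce, launch the script with torchrun (which sets LOCAL_RANK and the env:// rendezvous that init_process_group uses by default), e.g. `torchrun --nproc_per_node=2 <script_name>.py` on a machine with at least two GPUs; the trace files are written by rank 0 (the script filename here is a placeholder).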
Computation Trace (After PR) - NOTE: we don't change the signature of the trace; it still takes all the parameters as input, but we simply don't use the shared duplicates. This is fine, as they are backed by the same tensor and don't incur extra memory.
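A minimal sketch of why the unused trace inputs are free (this only restates the aliasing from the example above, it is not taken from the PR): two parameter names assigned the same Parameter object point at the same storage, so passing both to the trace allocates nothing extra.

import torch

fc1 = torch.nn.Linear(16, 16, bias=False)
fc2 = torch.nn.Linear(16, 16, bias=False)
fc2.weight = fc1.weight  # share the parameter, as in the example above

# Both names refer to one Parameter and one underlying storage.
assert fc1.weight is fc2.weight
assert fc1.weight.data_ptr() == fc2.weight.data_ptr()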
Ref: https://github.com/Lightning-AI/lightning-thunder/issues/257 and https://github.com/Lightning-AI/lightning-thunder/issues/257#issuecomment-2133070264