Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Make DDP/FSDP a regular transform #122

Open t-vi opened 5 months ago

t-vi commented 5 months ago

🚀 Feature

Make DDP/FSDP a regular transform (which, to a large part, means making transforms flexible enough to support this).

Motivation

Currently DDP/FSDP is not a regular transform, which leads to problems like #94 and limits composability and sequencing with other transforms. A key point is that, as a transform, DDP/FSDP would itself have to make the prologue adjustments that we currently make during tracing when DDP/FSDP is enabled, so we need to allow transforms to mutate prologues. This is in line with similar needs of other transforms (LoRA, quantization, but also value-and-grad things) that change prologue signatures, so this generalization should happen.
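To make "mutating the prologue through a transform" concrete, here is a minimal toy sketch in plain Python. It is not Thunder's API: `ParamCheck`, `shard_prologue`, and `run_prologue` are hypothetical names, the prologue is modeled as a simple list of expected parameter shapes, and sharding is assumed to split dimension 0 evenly across ranks.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ParamCheck:
    """Hypothetical prologue entry: the shape a named parameter must have."""
    name: str
    shape: tuple


def shard_prologue(prologue, world_size):
    """Toy FSDP-like transform: rewrite every check to expect a dim-0 shard.

    Assumption: dimension 0 is evenly divisible by world_size (no padding).
    """
    sharded = []
    for check in prologue:
        dim0, *rest = check.shape
        if dim0 % world_size != 0:
            raise ValueError(f"{check.name}: dim 0 not divisible by {world_size}")
        sharded.append(replace(check, shape=(dim0 // world_size, *rest)))
    return sharded


def run_prologue(prologue, param_shapes):
    """Validate the shapes passed at call time against the prologue."""
    for check in prologue:
        actual = param_shapes[check.name]
        if actual != check.shape:
            raise TypeError(f"{check.name}: expected {check.shape}, got {actual}")


# The prologue as traced for the unsharded model.
prologue = [ParamCheck("linear.weight", (8, 4)), ParamCheck("linear.bias", (8,))]

# The transform changes what the prologue accepts, not just the computation.
sharded_prologue = shard_prologue(prologue, world_size=2)
run_prologue(sharded_prologue, {"linear.weight": (4, 4), "linear.bias": (4,)})
```

In this toy model, the unmodified prologue would reject the sharded parameters, which is why a sharding transform needs to be able to change the prologue as well as the computation.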

cc @carmocca @awaelchli @crcrpar

IvanYashchuk commented 4 months ago

What is meant by making DDP/FSDP a regular transform? What are you planning to do? Today it's not a transform at all, as I commented here https://github.com/Lightning-AI/lightning-thunder/issues/94#issuecomment-2048269542. thunder.distributed.ddp/fsdp only annotate parameters for tracing. It's also described in the tutorial https://github.com/Lightning-AI/lightning-thunder/blob/main/notebooks/dev_tutorials/fsdp_tutorial.ipynb

I don't see any way for sharding to happen after the thunder.jit(model) call. What ideas do you have? The current workflow is:

  1. Shard the model. This is done with thunder.distributed.fsdp(model), or with torch.distributed.fsdp.FullyShardedDataParallel in PyTorch.
  2. Set up the optimizer using the sharded model so that the optimizer state is a shard.
  3. Call thunder.jit(sharded_model), or torch.compile(sharded_model) in PyTorch.
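
The dependency between steps 1 and 2 can be illustrated with a toy sketch in plain Python. It uses no Thunder or PyTorch calls: `shard` and `make_optimizer_state` are hypothetical helpers, parameters are flat lists, and an even split across ranks is assumed. Step 3 (compilation) is not modeled.

```python
def shard(params, rank, world_size):
    """Step 1 (toy): keep only this rank's slice of each flat parameter."""
    sharded = {}
    for name, values in params.items():
        chunk = len(values) // world_size  # assumes an even split
        sharded[name] = values[rank * chunk:(rank + 1) * chunk]
    return sharded


def make_optimizer_state(params):
    """Step 2 (toy): allocate one momentum slot per parameter element."""
    return {name: [0.0] * len(values) for name, values in params.items()}


full_params = {"w": [0.1] * 8}

# Optimizer built after sharding: its state matches the local shard.
local_params = shard(full_params, rank=0, world_size=4)
sharded_state = make_optimizer_state(local_params)

# Optimizer built before sharding: its state covers the full parameter.
unsharded_state = make_optimizer_state(full_params)

print(len(sharded_state["w"]), len(unsharded_state["w"]))
```

Whatever order the jit call ends up in, the optimizer has to see the parameters in their sharded form for its state to be sharded too.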

t-vi commented 4 months ago

I would like to move 3 up (for thunder.jit).

IvanYashchuk commented 4 months ago

Is the preferred order then 3 -> 1 -> 2?

t-vi commented 4 months ago

So per discussions with @crcrpar and @IvanYashchuk (thank you!)

(Obviously, the good ideas are from Masaki and Ivan; the not-so-good ones are my own.)

mruberry commented 4 months ago

triage review — let's start design review with draft PR to discuss