Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

FSDP2 & Thunder looks memory hungrier than `thunder.distributed.fsdp` for certain models #1176

Open crcrpar opened 1 day ago

crcrpar commented 1 day ago

🐛 Bug

Take stablecode-completion-alpha-3b, whose sequence length (`Config.block_size`) is 16384:

torchrun --standalone --role rank --tee 3 --local-ranks-filter 0 --nproc-per-node 8 thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --distributed_mode fsdp2 --shard_mode zero2 --compile thunder_inductor_cat_cudnn_dynamo

This command runs out of memory (OOM), while the same configuration (model and `--compile` option) runs on a single H100 with a memory usage of 77.02 GB.

Memory usage (GB) for a sequence length of 16384:

| FSDP impl | Thunder | Torch Compile | Diff (Thunder - Torch Compile) |
|---|---|---|---|
| thunder fsdp | 67.74 | N/A | N/A |
| FSDP2 | OOM | 62.7 | N/A |
| FSDP1 | N/A | 62.97 | N/A |

Memory usage (GB) for a sequence length of 8192:

| FSDP impl | Thunder | Torch Compile | Diff (Thunder - Torch Compile) |
|---|---|---|---|
| thunder fsdp | 37.58 | N/A | N/A |
| FSDP2 | 56.69 | 35.21 | 21.48 |
| FSDP1 | N/A | 35.24 | N/A |

When `--distributed_mode` is `fsdp`, the benchmark script uses `thunder.distributed.fsdp` if the `--compile` value is a thunder option without the `dynamo` keyword, and FSDP1 for all other `--compile` values.
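The selection rule above, restated as a minimal Python sketch. `select_fsdp_impl` is a hypothetical name, and modeling the rule as a plain substring check on the `--compile` value is an assumption; the actual code in `thunder/benchmarks/benchmark_litgpt.py` may be structured differently.

```python
def select_fsdp_impl(distributed_mode: str, compile_mode: str) -> str:
    # Hypothetical helper restating the rule for --distributed_mode fsdp.
    # Assumption: a plain substring check on the --compile value.
    if distributed_mode != "fsdp":
        raise ValueError("this sketch only covers --distributed_mode fsdp")
    if "thunder" in compile_mode and "dynamo" not in compile_mode:
        return "thunder.distributed.fsdp"
    # Everything else (thunder + dynamo, eager, torch.compile, ...) uses FSDP1.
    return "FSDP1"


for mode in ["thunder", "thunder_inductor_cat_cudnn_dynamo", "eager"]:
    print(f"{mode}: {select_fsdp_impl('fsdp', mode)}")
```

Under this reading, `thunder_inductor_cat_cudnn_dynamo` with `--distributed_mode fsdp` would fall back to FSDP1, which is why the Thunder column only has a "thunder fsdp" row for the non-FSDP2 runs.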

FSDP2 + Thunder clearly uses far more memory than thunder's fsdp (56.69 GB vs. 37.58 GB at a sequence length of 8192, and OOM vs. 67.74 GB at 16384), while thunder's fsdp itself seems to use more memory than Eager and Torch Compile. I did not see this trend in memory usage when I was working on #940. For Llama-3-8B, Thunder still uses more memory, but the gap is much smaller:

| FSDP impl | Thunder | Torch Compile | Diff (Thunder - Torch Compile) |
|---|---|---|---|
| thunder fsdp | 75.79 | N/A | N/A |
| FSDP2 | 74.84 | 72.61 | 2.23 |
| FSDP1 | N/A | 73.41 | N/A |
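The Diff column is Thunder minus Torch Compile. The sketch below recomputes it for the FSDP2 rows and adds the relative overhead; the numbers are copied from the tables above, and `thunder_overhead` is a purely illustrative helper.

```python
# FSDP2 memory usage (GB) copied from the tables above: (Thunder, Torch Compile).
# None marks the Thunder run that went OOM.
fsdp2_runs = {
    "stablecode-completion-alpha-3b, seq 16384": (None, 62.7),
    "stablecode-completion-alpha-3b, seq 8192": (56.69, 35.21),
    "Llama-3-8B": (74.84, 72.61),
}


def thunder_overhead(thunder_gb, compile_gb):
    """Return (extra GB, extra %) of Thunder over torch.compile, or None on OOM."""
    if thunder_gb is None:
        return None
    diff_gb = round(thunder_gb - compile_gb, 2)
    return diff_gb, round(100 * diff_gb / compile_gb, 1)


for run, (thunder_gb, compile_gb) in fsdp2_runs.items():
    print(run, thunder_overhead(thunder_gb, compile_gb))
```

With FSDP2, Thunder's overhead on stablecode-completion-alpha-3b at 8192 tokens comes to roughly 61% (21.48 GB), compared with about 3% (2.23 GB) on Llama-3-8B.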

To Reproduce

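Apply a diff like the one below, which adds a `block_size` argument to `Benchmark_litGPT` in `thunder/benchmarks/benchmark_litgpt.py`, and run a command such as: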

torchrun --standalone --role rank --tee 3 --local-ranks-filter 0 thunder/benchmarks/benchmark_litgpt.py --model_name stablecode-completion-alpha-3b --warmup_iters 0 --max_iters 3 --compile eager --dump_memory_snapshot false --block_size 2048
@@ -227,6 +269,7 @@ class Benchmark_litGPT:
         fsdp_bucket_params: float | None = None,
         checkpoint_activations: bool = False,
         n_layers: int | None = None,
+        block_size: int | None = None,
         profiler_start: int = 15,
         profiler_stop: int = 15,
         skip_data_sync: bool = False,
@@ -360,6 +403,8 @@ class Benchmark_litGPT:

         if n_layers is not None:
             self.config.n_layer = n_layers
+        if block_size is not None:
+            self.config.block_size = block_size

         # Initialize the model
         t0 = time.perf_counter()

Environment

pjnl-20240919

Additional context

related to #1175

cc @carmocca @crcrpar

crcrpar commented 1 day ago

The two models differ in their normalization layers (LayerNorm vs. RMSNorm) and their MLPs (GptNeoxMLP vs. LLaMAMLP).

FWIW, the MLP used in stablecode (GptNeoxMLP) is not benchmarked, as per #742.
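For reference, here are the standard definitions of the differing components, matched to the module printouts. This assumes litgpt's `LayerNorm`/`RMSNorm`, `GptNeoxMLP` (GELU) and `LLaMAMLP` (SiLU-gated) follow these textbook forms; the formulas are not taken from this issue. Here $d$ is the hidden size and $\odot$ is the elementwise product.

```latex
\mathrm{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,
\qquad \mu = \frac{1}{d}\sum_{i=1}^{d} x_i,\quad \sigma^2 = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2

\mathrm{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}

\mathrm{GptNeoxMLP}(x) = W_{\mathrm{proj}}\,\mathrm{GELU}\left(W_{\mathrm{fc}}\,x + b_{\mathrm{fc}}\right) + b_{\mathrm{proj}}

\mathrm{LLaMAMLP}(x) = W_{\mathrm{proj}}\left(\mathrm{SiLU}(W_{\mathrm{fc}_1}\,x) \odot W_{\mathrm{fc}_2}\,x\right)
```

The bias terms appear only in the stablecode forms, matching `bias=True` in its printout and `bias=False` in the Llama-3-8B one.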

# stablecode-completion-alpha-3b
GPT(
  (lm_head): Linear(in_features=2560, out_features=49152, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(49152, 2560)
    (h): ModuleList(
      (0): Block(
        (norm_1): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=2560, out_features=7680, bias=True)
          (proj): Linear(in_features=2560, out_features=2560, bias=True)
        )
        (post_attention_norm): Identity()
        (norm_2): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (mlp): GptNeoxMLP(
          (fc): Linear(in_features=2560, out_features=10240, bias=True)
          (proj): Linear(in_features=10240, out_features=2560, bias=True)
        )
        (post_mlp_norm): Identity()
      )
    )
    (ln_f): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  )
)

# Llama-3-8B
GPT(
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(128256, 4096)
    (h): ModuleList(
      (0): Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=6144, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (post_attention_norm): Identity()
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=14336, bias=False)
          (fc_2): Linear(in_features=4096, out_features=14336, bias=False)
          (proj): Linear(in_features=14336, out_features=4096, bias=False)
        )
        (post_mlp_norm): Identity()
      )
    )
    (ln_f): RMSNorm()
  )
)
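To compare the two printouts directly, here is a sketch that counts the parameters of one `Block` and of the token embedding from the shapes shown. Only block 0 is printed, so the layer count is not covered. The helper names are hypothetical, and the RMSNorm count assumes litgpt's `RMSNorm` holds a single weight of size 4096, which the printout does not show.

```python
def linear_params(in_features: int, out_features: int, bias: bool) -> int:
    # nn.Linear weight is (out_features, in_features); bias adds out_features.
    return in_features * out_features + (out_features if bias else 0)


def layernorm_params(d: int) -> int:
    return 2 * d  # elementwise_affine=True: weight and bias


def rmsnorm_params(d: int) -> int:
    return d  # assumption: weight only, no bias


stablecode_block = (
    2 * layernorm_params(2560)            # norm_1, norm_2
    + linear_params(2560, 7680, True)     # attn.attn (fused QKV)
    + linear_params(2560, 2560, True)     # attn.proj
    + linear_params(2560, 10240, True)    # mlp.fc
    + linear_params(10240, 2560, True)    # mlp.proj
)
llama3_block = (
    2 * rmsnorm_params(4096)                 # norm_1, norm_2
    + linear_params(4096, 6144, False)       # attn.attn (fused QKV, grouped K/V)
    + linear_params(4096, 4096, False)       # attn.proj
    + 2 * linear_params(4096, 14336, False)  # mlp.fc_1, mlp.fc_2
    + linear_params(14336, 4096, False)      # mlp.proj
)
stablecode_embed = 49152 * 2560   # transformer.wte
llama3_embed = 128256 * 4096      # transformer.wte

print(f"stablecode block: {stablecode_block:,}")  # 78,676,480
print(f"Llama-3-8B block: {llama3_block:,}")      # 218,112,000
print(f"stablecode wte:   {stablecode_embed:,}")  # 125,829,120
print(f"Llama-3-8B wte:   {llama3_embed:,}")      # 525,336,576
```

Per block, Llama-3-8B has about 2.8x as many parameters as stablecode-completion-alpha-3b, even though its Thunder memory overhead in the tables above is far smaller.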