pytorch / torchtitan

A native PyTorch Library for large model training

[cp] apply fsdp to model when CP is enabled without DP for correct loss and lower mem usage #685

XilunWu commented 1 week ago

Stack from ghstack (oldest at bottom):

Summary

Previously, CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This led to high peak memory usage and diverging loss.
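For readers following along, below is a minimal, self-contained sketch of the idea behind the fix, not torchtitan's actual parallelize code: FSDP2 sharding (`fully_shard`) is applied over the flattened DP x CP mesh whenever that mesh spans more than one rank, so a CP-only run no longer keeps a full parameter replica on every rank. `ToyModel`, `ToyBlock`, `apply_fsdp`, `parallelize`, the `dp_cp` mesh name, and the script name are illustrative assumptions; CP's attention/sequence sharding itself is handled elsewhere and is not shown here.

```python
# Minimal sketch (NOT torchtitan's actual parallelize code) of applying FSDP2
# sharding when CP is enabled without DP. Launch with, e.g.:
#   torchrun --nproc_per_node=8 cp_fsdp_sketch.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard  # FSDP2


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x))


class ToyModel(nn.Module):
    """Stand-in for a Llama-style model exposing a `layers` ModuleList."""

    def __init__(self, dim: int = 256, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return x


def apply_fsdp(model: nn.Module, mesh) -> None:
    # Shard each block first, then the root module (per-block FSDP2 wrapping).
    for block in model.layers:
        fully_shard(block, mesh=mesh)
    fully_shard(model, mesh=mesh)


def parallelize(model: nn.Module, dp_degree: int, cp_degree: int) -> nn.Module:
    # The bug, in spirit: sharding used to be gated on dp_degree > 1 alone, so
    # a CP-only run (dp_degree == 1, cp_degree > 1) skipped apply_fsdp, kept a
    # full replica on every rank, and hit CUDA OOM. Gating on the combined
    # mesh size shards the parameters in both the DP and the CP-only cases.
    if dp_degree * cp_degree > 1:
        mesh = init_device_mesh(
            "cuda", (dp_degree * cp_degree,), mesh_dim_names=("dp_cp",)
        )
        apply_fsdp(model, mesh)
    return model


if __name__ == "__main__":
    # torchrun sets LOCAL_RANK; init_device_mesh initializes the default
    # process group lazily from the torchrun environment variables.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = ToyModel().cuda()
    model = parallelize(model, dp_degree=1, cp_degree=8)
    out = model(torch.randn(2, 256, device="cuda"))
    print(f"output shape: {tuple(out.shape)}")
```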

Test

  1. Modify train_configs/llama3_8b.toml:
     steps = 20
     context_parallel_degree = 8
  2. Run training on 8x H100 GPUs:
     CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh
     Before: CUDA OutOfMemory
     After: successful 20-step training
     (See the optional peak-memory check below.)
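For completeness, here is a hypothetical per-rank peak-memory check (not part of this PR's test plan) that can be dropped into the training loop to quantify the before/after difference; it relies only on PyTorch's built-in CUDA memory statistics (`torch.cuda.max_memory_allocated`, `torch.cuda.reset_peak_memory_stats`), and the helper name is an illustrative assumption.

```python
# Hypothetical helper (not part of this PR) for comparing per-rank peak memory
# before and after the fix, using PyTorch's built-in CUDA memory stats.
import torch


def log_peak_cuda_memory(step: int) -> None:
    """Print this rank's CUDA high-water mark for the last step, then reset it."""
    peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"step {step}: peak CUDA memory {peak_gib:.2f} GiB")
    torch.cuda.reset_peak_memory_stats()


# Usage: call log_peak_cuda_memory(step) after each optimizer step. Without the
# fix, a CP-only llama3_8b run exceeds an H100's 80 GiB and OOMs; with the fix,
# the reported peak should stay below that.
```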