Summary
Previously, when CP was enabled without DP, the model was not sharded via `apply_fsdp`. This led to high peak memory usage and diverging loss.
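A minimal sketch of the missing branch, with hypothetical names (`maybe_apply_fsdp_for_cp`, the `world_mesh["cp"]` sub-mesh, and the degree arguments are illustrative, not the actual torchtitan signatures): when CP is enabled without DP, the model still has to be wrapped with FSDP over the CP mesh so that parameters are sharded and gradients are reduced across CP ranks.

```python
# Hypothetical sketch, not the actual torchtitan code path.
import torch
from torch.distributed._composable.fsdp import fully_shard


def maybe_apply_fsdp_for_cp(
    model: torch.nn.Module, world_mesh, dp_degree: int, cp_degree: int
) -> torch.nn.Module:
    """Shard the model over the CP mesh when CP is used without DP."""
    if cp_degree > 1 and dp_degree == 1:
        # Previously this case fell through without calling apply_fsdp,
        # leaving every rank with a full, unsynchronized copy of the model:
        # high peak memory and diverging loss.
        fully_shard(model, mesh=world_mesh["cp"])
    return model
```

With the model sharded this way, FSDP reduce-scatters gradients across the CP ranks, which both lowers peak memory and keeps the loss from diverging.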
Test
Modify train_configs/llama3_8b.toml (see the config snippet below):
steps = 20
context_parallel_degree = 8
Run training on 8x H100 GPUs:
CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh
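For reference, the config change in TOML form; the enclosing table is not shown here and follows wherever `steps` and `context_parallel_degree` already live in train_configs/llama3_8b.toml:

```toml
# train_configs/llama3_8b.toml -- only these two keys are changed for this test;
# table placement follows the existing file.
steps = 20
context_parallel_degree = 8
```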
Before: CUDA OutOfMemory
After: successful 20-step training run
Stack from ghstack (oldest at bottom):
* #684
* #683