pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Skip data loading for middle PP ranks #411

Closed wconstab closed 3 months ago

wconstab commented 3 months ago


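Only the first and last pipeline-parallel (PP) ranks need to perform data loading: the first rank consumes input_ids, and the last rank needs the matching labels to compute the loss. Middle ranks only receive activations from the previous stage, so they can skip data loading.

The downside of skipping data loading on middle ranks is added complexity in train.py, including in how metrics are handled.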
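
As a rough illustration of the rank check described above (not the code from this PR), here is a minimal sketch. The helper names `pp_rank_needs_data` and `batch_fields_for_rank` are hypothetical, and the sketch assumes each PP rank holds exactly one pipeline stage, which would not hold for looped or interleaved schedules.

```python
from typing import Tuple


def pp_rank_needs_data(pp_rank: int, pp_size: int) -> bool:
    """Hypothetical helper: True if this PP rank must load a batch.

    Assumes one pipeline stage per rank, so only the first stage
    (input_ids) and the last stage (labels) touch the dataset.
    """
    if not 0 <= pp_rank < pp_size:
        raise ValueError(f"pp_rank {pp_rank} out of range for pp_size {pp_size}")
    return pp_rank == 0 or pp_rank == pp_size - 1


def batch_fields_for_rank(pp_rank: int, pp_size: int) -> Tuple[str, ...]:
    """Hypothetical helper: which batch fields this PP rank consumes."""
    if not 0 <= pp_rank < pp_size:
        raise ValueError(f"pp_rank {pp_rank} out of range for pp_size {pp_size}")
    fields = []
    if pp_rank == 0:
        fields.append("input_ids")  # first stage feeds tokens into the model
    if pp_rank == pp_size - 1:
        fields.append("labels")  # last stage computes the loss
    return tuple(fields)


if __name__ == "__main__":
    pp_size = 4
    for rank in range(pp_size):
        print(rank, pp_rank_needs_data(rank, pp_size), batch_fields_for_rank(rank, pp_size))
```

With `pp_size = 4`, ranks 1 and 2 would skip the data loader entirely. The first and last ranks still have to iterate the dataset in lockstep so that input_ids and labels come from the same batch.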

wconstab commented 3 months ago

Honestly, I'm not sure whether we want to land this PR. It is not urgent in any case, and we could run some experiments to decide whether it's more critical to reduce data-loader stress or to keep compute/comms balanced per rank and avoid timeouts. Closing for now.

cc @tianyu-l @wanchaol @awgu