pytorch / torchtitan

A native PyTorch Library for large model training

Make metrics logging work for pipeline parallelism #383

Closed · wconstab closed this 4 weeks ago

wconstab commented 4 weeks ago

Stack from ghstack (oldest at bottom):

Avoid complicating the UX and keep the status quo of two user-selectable behaviors:

Modify the meaning of 'log from rank 0' so that non-pipeline-parallel runs log from global rank 0, while runs with pipeline parallelism enabled log from the local rank 0 within the last pipeline-stage group. (Note: earlier pipeline stages still produce some metrics, such as MFU and memory usage, but do not compute loss.)
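For concreteness, here is a minimal sketch of that selection rule. This is a hypothetical helper, not the actual torchtitan implementation; the function name `should_log_metrics`, the `pp_enabled` flag, and the `"pp"` mesh dim name are assumptions, and it presumes the world `DeviceMesh` was built with named dims and that `torch.distributed` is initialized:

```python
# Hypothetical sketch (not torchtitan's actual code) of the metrics-rank
# selection rule described above.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def should_log_metrics(world_mesh: DeviceMesh, pp_enabled: bool) -> bool:
    """Return True if the calling rank should emit metrics logs."""
    if not pp_enabled:
        # Status quo: only global rank 0 logs.
        return dist.get_rank() == 0
    # Loss is only computed on the last pipeline stage, so log from the
    # local rank 0 within the last pipeline-stage group instead.
    on_last_stage = (
        world_mesh.get_local_rank("pp") == world_mesh["pp"].size() - 1
    )
    # "Local rank 0" within that stage group: coordinate 0 on every
    # non-pp mesh dimension.
    first_in_stage_group = all(
        world_mesh.get_local_rank(dim) == 0
        for dim in world_mesh.mesh_dim_names
        if dim != "pp"
    )
    return on_last_stage and first_in_stage_group
```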

wconstab commented 4 weeks ago

> We probably need to extend DeviceMesh to make calculating a specific rank easier.

Yes, my thoughts exactly; I would like to discuss this offline. I couldn't quickly think of the best API proposal for DeviceMesh, so I went this route instead.
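For context on why a DeviceMesh extension would help here: computing "the global rank at this mesh coordinate" today means indexing the underlying mesh tensor by hand. A hypothetical helper (the name `metrics_log_rank` and the coordinate convention are assumptions mirroring the rule above) might look like:

```python
# Hypothetical illustration, not actual torchtitan code: resolve the global
# rank of local rank 0 on the last pipeline stage by indexing the raw mesh
# tensor, since DeviceMesh has no dedicated API for this lookup.
from torch.distributed.device_mesh import DeviceMesh


def metrics_log_rank(world_mesh: DeviceMesh) -> int:
    """Global rank of local rank 0 within the last pipeline stage."""
    # Coordinate: last index on the "pp" dim, 0 on every other dim.
    pp_dim = world_mesh.mesh_dim_names.index("pp")
    coord = [0] * world_mesh.mesh.ndim
    coord[pp_dim] = world_mesh.mesh.size(pp_dim) - 1
    # world_mesh.mesh is the tensor of global ranks, laid out by mesh dims.
    return int(world_mesh.mesh[tuple(coord)].item())
```

A rank computed this way could then be handed to the metrics logger as its destination; the DeviceMesh discussion above is about making this kind of lookup a first-class API rather than manual tensor indexing.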