pytorch / torchtitan

Refactor freqs_cis slice to be safer for PP #321

Closed by wconstab 1 month ago

wconstab commented 1 month ago

Stack from ghstack (oldest at bottom):

Unchanged: we precompute freqs_cis up to max_seqlen, which is generally much larger than the seqlen of any given batch.
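
For context, a minimal sketch of what such a precompute looks like, assuming standard Llama-style RoPE (the exact function in torchtitan may differ):

```python
import torch

def precompute_freqs_cis(dim: int, max_seqlen: int, theta: float = 10000.0) -> torch.Tensor:
    # One complex rotation per (position, frequency) pair, computed once for
    # max_seqlen, which is typically much larger than any batch's seqlen.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seqlen).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, [max_seqlen, dim // 2]
```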

Changed: instead of slicing self.freqs_cis down to seqlen at the top-level transformer based on the input token shape, we slice it down to seqlen inside each transformer layer, after the activation has been re-expanded to the full seqlen in cases where TP has sharded it across the sequence dimension.
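
A sketch of the before/after shape of the change (the names and signatures here are illustrative, not the literal diff):

```python
# Before (sketch): the top-level Transformer sliced based on its own input shape.
# Under PP, a later stage's input may already be TP-sharded along the sequence
# dim, so tokens.shape[1] is not a reliable source for the true seqlen.
def transformer_forward(self, tokens):
    _bsz, seqlen = tokens.shape
    freqs_cis = self.freqs_cis[:seqlen]  # wrong slice if seqlen is really seqlen/TP
    ...

# After (sketch): pass the full table down and slice inside the layer, where
# the activation has been re-expanded to the full seqlen.
def layer_forward(self, x, freqs_cis):
    seqlen = x.shape[1]  # unambiguous here: x is full-seqlen at this point
    freqs_cis = freqs_cis[:seqlen]
    ...
```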

In the PP case, stage 1's input may have sequence length seqlen/TP rather than seqlen, but stage 1 has no general way of knowing which it received. That makes it hard for stage 1 to slice freqs_cis correctly. It is easy to do the slicing deeper inside each layer, since at that point the full seqlen is known unambiguously.
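
For example, with made-up numbers:

```python
# Hypothetical shapes illustrating the ambiguity a PP stage faces:
seqlen, tp_degree = 8192, 4
stage1_input_seqlen = seqlen // tp_degree  # 2048, because TP sharded the sequence dim
# To stage 1, an input with sequence length 2048 is indistinguishable from an
# unsharded batch whose true seqlen is 2048, so it cannot know whether
# self.freqs_cis[:2048] is the right slice.
```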

Note: the full self.freqs_cis is stored in memory either way, and what is passed into each layer is just a view. This change should not materially affect memory usage or anything else.
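
A quick way to confirm that basic slicing returns a view rather than a copy (standard PyTorch behavior, not torchtitan-specific):

```python
import torch

freqs_cis = torch.randn(4096, 64, dtype=torch.complex64)  # stand-in for the full table
per_layer = freqs_cis[:2048]  # the kind of slice each layer receives
# Basic slicing shares the base tensor's storage; no new allocation is made.
assert per_layer.untyped_storage().data_ptr() == freqs_cis.untyped_storage().data_ptr()
```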