Unchanged: we precompute freqs_cis for max_seqlen, which is much larger
than the seqlen of a given batch.
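For context, here is a minimal sketch of that precomputation. It assumes the standard Llama-style RoPE formula with theta = 10000 and uses plain Python complex numbers in nested lists instead of torch tensors so that it has no dependencies. The function name and signature are illustrative, not torchtitan's exact API.

```python
import cmath
from typing import List


def precompute_freqs_cis(dim: int, max_seqlen: int, theta: float = 10000.0) -> List[List[complex]]:
    """Return a max_seqlen x (dim // 2) table of unit complex rotations.

    Entry [t][i] is exp(1j * t * theta ** (-2i / dim)): the rotation applied
    to the i-th pair of features at position t.
    """
    inv_freqs = [theta ** (-(2 * i) / dim) for i in range(dim // 2)]
    return [[cmath.exp(1j * t * f) for f in inv_freqs] for t in range(max_seqlen)]


# Computed once for max_seqlen, independent of any particular batch's seqlen.
freqs_cis = precompute_freqs_cis(dim=8, max_seqlen=2048)
print(len(freqs_cis), len(freqs_cis[0]))  # 2048 4
```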
Changed: instead of slicing self.freqs_cis down to seqlen in the
top-level Transformer based on the input token shape, we now slice it
down to seqlen inside each transformer layer, after the activations have
been re-expanded to the full seqlen in cases where TP has sharded across
seqlen.
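A minimal sketch of the new placement, again with nested lists standing in for tensors. `all_gather_seq` is a hypothetical stand-in for the TP collective that restores the full seqlen, and `attention_layer` models only the RoPE step of an attention layer. This illustrates the idea under those assumptions; it is not the actual torchtitan code.

```python
import cmath
from typing import List

Row = List[complex]


def precompute_freqs_cis(dim: int, max_seqlen: int, theta: float = 10000.0) -> List[Row]:
    inv_freqs = [theta ** (-(2 * i) / dim) for i in range(dim // 2)]
    return [[cmath.exp(1j * t * f) for f in inv_freqs] for t in range(max_seqlen)]


def all_gather_seq(shards: List[List[Row]]) -> List[Row]:
    # Hypothetical stand-in for the TP collective that re-expands a
    # sequence-sharded activation to the full seqlen.
    return [row for shard in shards for row in shard]


def apply_rotary_emb(x: List[Row], freqs_cis: List[Row]) -> List[Row]:
    seqlen = len(x)                # x is already back at the full seqlen
    freqs = freqs_cis[:seqlen]     # slice here, inside the layer
    assert len(freqs) == seqlen, "seqlen exceeds max_seqlen"
    # Rotate every complex feature pair by its position's entry.
    return [[v * f for v, f in zip(row, frow)] for row, frow in zip(x, freqs)]


def attention_layer(x_shards: List[List[Row]], freqs_cis: List[Row]) -> List[Row]:
    x = all_gather_seq(x_shards)   # seqlen/TP rows per shard -> seqlen rows
    return apply_rotary_emb(x, freqs_cis)


freqs_cis = precompute_freqs_cis(dim=4, max_seqlen=64)  # full table, sliced later
seqlen, tp = 8, 2
x = [[1 + 0j, 1 + 0j] for _ in range(seqlen)]
shards = [x[r * seqlen // tp:(r + 1) * seqlen // tp] for r in range(tp)]
out = attention_layer(shards, freqs_cis)
print(len(out))  # 8
```

Because each input feature is 1, every output entry equals the corresponding freqs_cis entry, which makes it easy to check that position t was rotated by row t.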
In the PP case, the input to stage 1 may have length seqlen/TP rather
than seqlen, but stage 1 does not generally know which. That makes it
hard for stage 1 to slice freqs_cis correctly. Slicing deeper inside is
easy, since at that point the full seqlen is known unambiguously.
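The PP problem can be illustrated with a toy model of a single stage. In this sketch each freqs_cis row is replaced by its position index, because only the number of rows matters here. `identity` models the first stage, whose token input is not sharded along seqlen. `all_gather` is a hypothetical stand-in for the TP re-expansion seen by a later stage whose input is sequence-sharded. All names are illustrative assumptions.

```python
from typing import Callable, List

MAX_SEQLEN = 2048
# Stand-in for freqs_cis: row t only records its position t.
FREQS_CIS = list(range(MAX_SEQLEN))


def rows_match(local_h: List[int], regather: Callable[[List[int]], List[int]],
               slice_in_layer: bool) -> bool:
    """Run one hypothetical PP stage; True if freqs_cis rows match the sequence."""
    if slice_in_layer:
        h = regather(local_h)              # TP re-expands to the full seqlen
        freqs = FREQS_CIS[: len(h)]        # new: slice where seqlen is unambiguous
    else:
        freqs = FREQS_CIS[: len(local_h)]  # old: slice by the stage's input shape
        h = regather(local_h)
    return len(freqs) == len(h)


seqlen, tp = 16, 4
tokens = list(range(seqlen))
shards = [tokens[r * seqlen // tp:(r + 1) * seqlen // tp] for r in range(tp)]


def identity(h: List[int]) -> List[int]:
    return h  # first stage: token input is not sharded along seqlen


def all_gather(_local: List[int]) -> List[int]:
    # Hypothetical stand-in for the TP all-gather across sequence shards.
    return [t for shard in shards for t in shard]


print(rows_match(tokens, identity, slice_in_layer=False))       # True
print(rows_match(shards[0], all_gather, slice_in_layer=False))  # False (4 vs 16)
print(rows_match(shards[0], all_gather, slice_in_layer=True))   # True
```

Slicing by the stage's input shape only works when that input is unsharded. Slicing after the re-expansion works for every stage.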
Note: the full self.freqs_cis is stored in memory either way, and what
is passed into every layer is just a view, so this change should not
materially affect memory usage or anything else.
Stack from ghstack (oldest at bottom):
318
322