pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Make Transformer tolerate missing layers for PP #322

Closed: wconstab closed this 1 month ago

wconstab commented 1 month ago


A few small changes here let the manual pipeline-parallel (PP) frontend 'reconfigure' a whole Transformer model into a single stage's portion, simply by setting unwanted top-level layers to None or deleting unwanted entries from the ModuleDict (for 'layers.*').
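A minimal sketch of the idea, assuming a hypothetical `ToyTransformer` whose top-level modules are named `tok_embeddings`, `norm`, and `output` and whose blocks live in an `nn.ModuleDict` keyed by `str(layer_id)`. These names and the `split_two_stages` helper are illustrative assumptions, not code from this PR. The forward pass skips any top-level module that is `None`, so each stage can be carved out of a full copy of the model:

```python
import copy

import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    """Hypothetical stand-in for a Transformer, used only for illustration."""

    def __init__(self, vocab_size: int = 32, dim: int = 16, n_layers: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        # Blocks keyed by str(layer_id): deleting some keeps the others' names.
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Any top-level module may have been set to None by a stage that does not own it.
        h = self.tok_embeddings(x) if self.tok_embeddings is not None else x
        for layer in self.layers.values():
            h = layer(h)
        h = self.norm(h) if self.norm is not None else h
        return self.output(h) if self.output is not None else h


def split_two_stages(full: ToyTransformer) -> tuple[ToyTransformer, ToyTransformer]:
    """Hypothetical helper: stage 0 owns the embedding and the first half of the
    blocks; stage 1 owns the second half plus the final norm and output."""
    stage0, stage1 = copy.deepcopy(full), copy.deepcopy(full)
    half = len(full.layers) // 2
    for name in list(full.layers.keys()):
        if int(name) < half:
            del stage1.layers[name]  # block owned by stage 0
        else:
            del stage0.layers[name]  # block owned by stage 1
    # Top-level modules a stage does not own are set to None.
    stage0.norm = None
    stage0.output = None
    stage1.tok_embeddings = None
    return stage0, stage1


torch.manual_seed(0)
full = ToyTransformer()
stage0, stage1 = split_two_stages(full)
tokens = torch.randint(0, 32, (2, 5))
# Running the two stages back to back reproduces the full model's output.
print(torch.allclose(stage1(stage0(tokens)), full(tokens)))
print(list(stage1.layers.keys()))
```

Setting a registered submodule to `None` is allowed by `nn.Module`, and `None` entries are skipped by `state_dict()`, so each stage only saves the weights it actually owns.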

These changes don't alter the fully qualified names (FQNs) of the remaining layers, which is critical for checkpoint save/load compatibility.
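The FQN point can be seen by contrasting `nn.ModuleDict` with `nn.ModuleList`: deleting an entry from a `ModuleList` renumbers the remaining entries, which shifts their state_dict keys, whereas a `ModuleDict` keeps its string keys unchanged. A small sketch, with a hypothetical `Holder` wrapper standing in for the parent model so the keys carry a `layers.` prefix:

```python
import torch.nn as nn


class Holder(nn.Module):
    """Hypothetical parent module so state_dict keys carry a 'layers.' prefix."""

    def __init__(self, layers: nn.Module):
        super().__init__()
        self.layers = layers


def fqns(layers: nn.Module) -> list[str]:
    # State-dict keys (FQNs) the container would have inside a parent model.
    return list(Holder(layers).state_dict().keys())


as_list = nn.ModuleList([nn.Linear(4, 4, bias=False) for _ in range(3)])
as_dict = nn.ModuleDict({str(i): nn.Linear(4, 4, bias=False) for i in range(3)})

# Drop layer 0 from both containers, as a later pipeline stage would.
del as_list[0]
del as_dict["0"]

print(fqns(as_list))  # ['layers.0.weight', 'layers.1.weight']: renumbered
print(fqns(as_dict))  # ['layers.1.weight', 'layers.2.weight']: original FQNs kept
```

With the `ModuleDict`, the keys a stage saves or loads line up with the corresponding keys of the full model without any remapping.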

fegin commented 1 month ago

Nice, but it is less intuitive than I originally expected, especially the int/str conversion part. I'm not sure whether this is the best UX for PiPPy, or whether a customized PipelineModuleList would be easier for users.
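The int/str conversion comes from `nn.ModuleDict` requiring string keys: integer layer ids have to be converted with `str(...)` to build or index the dict, and parsed back with `int(...)` to decide which layers a stage owns. A minimal sketch, assuming a hypothetical even split of layers across stages; `stage_layer_ids` and `prune_layers` are illustrative names, not part of torchtitan or PiPPy:

```python
import torch.nn as nn


def stage_layer_ids(n_layers: int, stage_idx: int, num_stages: int) -> range:
    """Hypothetical even split of layer ids across pipeline stages."""
    assert n_layers % num_stages == 0, "sketch assumes layers divide evenly"
    per_stage = n_layers // num_stages
    return range(stage_idx * per_stage, (stage_idx + 1) * per_stage)


def prune_layers(layers: nn.ModuleDict, keep_ids: range) -> None:
    # Keys are strings ("0", "1", ...), so each key is parsed back to an int
    # before comparing it with the integer ids this stage keeps.
    # list(...) snapshots the keys so the dict can be mutated while looping.
    for name in list(layers.keys()):
        if int(name) not in keep_ids:
            del layers[name]


layers = nn.ModuleDict({str(i): nn.Identity() for i in range(8)})
prune_layers(layers, stage_layer_ids(8, stage_idx=1, num_stages=4))
print(list(layers.keys()))  # ['2', '3']
# Lookups must also use strings: layers["2"] works, layers[2] raises KeyError.
```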

awgu commented 6 days ago

One downside of using ModuleDict is that printing the model no longer collapses the TransformerBlocks together, so the model print becomes very long.
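The print difference can be reproduced with plain containers. In recent PyTorch releases, `nn.ModuleList.__repr__` collapses runs of identical entries, while `nn.ModuleDict` uses the default per-child repr, so a deep model prints one entry per block. A small sketch using `nn.Linear` as a stand-in for a TransformerBlock; `repr_lines` is a hypothetical helper for counting printed lines:

```python
import torch.nn as nn

blocks_list = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
blocks_dict = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(4)})


def repr_lines(module: nn.Module) -> int:
    # Number of lines the module occupies when printed.
    return len(repr(module).splitlines())


print(blocks_list)  # identical entries may be collapsed into a single "(0-3): 4 x Linear(...)" line
print(blocks_dict)  # one "(i): Linear(...)" line per entry
print(repr_lines(blocks_list), repr_lines(blocks_dict))
```

With dozens of TransformerBlocks, each block's multi-line repr is repeated in full under `ModuleDict`, which is where the very long print comes from.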