pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Prepare train.py for model chunks for pipelining #406

Closed: wconstab closed this 3 months ago

wconstab commented 3 months ago


When using pipeline parallelism, a common technique for reducing the pipeline bubble is to use schedules that place more than one model chunk on each physical rank. For example, with a PP degree of 4 there could be 8 pipeline stages, and rank 0 could hold stages 0 and 4.
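As a rough illustration of that example, the sketch below assumes a round-robin ("looped") assignment in which stage `s` lives on rank `s % pp_degree`. This matches rank 0 holding stages 0 and 4, but other schedules may place stages differently. The function name `stages_for_rank` is hypothetical.

```python
def stages_for_rank(rank: int, pp_degree: int, num_stages: int) -> list[int]:
    """Return the stage indices held by `rank`, assuming round-robin placement."""
    if num_stages % pp_degree != 0:
        raise ValueError("num_stages must be a multiple of pp_degree")
    # Stage s lives on rank s % pp_degree, so each rank holds every
    # pp_degree-th stage starting from its own rank index.
    return list(range(rank, num_stages, pp_degree))


for rank in range(4):
    print(rank, stages_for_rank(rank, pp_degree=4, num_stages=8))
# 0 [0, 4]
# 1 [1, 5]
# 2 [2, 6]
# 3 [3, 7]
```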

To generalize this concept without forking too much code in train.py, I make 'model_parts' a new container that holds a single model in the non-PP and simple PP cases, and multiple model parts in the complex PP cases.

In general, this is tractable because we treat each model part the same way: we create one optimizer per model part and one LR scheduler per optimizer, and we apply SPMD parallelism and compilation to each model part individually. The general pattern is to loop over the model parts and perform an action on each one, which also works fine when the list has only one element.
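A minimal sketch of the per-part pattern, assuming plain `torch.optim.AdamW` optimizers and a linear-warmup `LambdaLR` schedule (both are illustrative choices, not necessarily what train.py uses). The two `nn.Linear` modules are hypothetical stand-ins for one rank's model chunks. Parallelization and compilation, which would be applied per part in the same kind of loop, are omitted, and the parts are chained locally rather than driven by a pipeline schedule.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

# Hypothetical stand-ins for the chunks one PP rank might hold (e.g. stages 0 and 4).
# For a non-PP run this list would simply have length 1.
model_parts = [nn.Linear(16, 16), nn.Linear(16, 16)]

# One optimizer per model part, and one LR scheduler per optimizer.
optimizers = [torch.optim.AdamW(m.parameters(), lr=1e-3) for m in model_parts]
warmup_steps = 10
lr_schedulers = [
    LambdaLR(opt, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))
    for opt in optimizers
]


def train_step(x: torch.Tensor) -> torch.Tensor:
    for opt in optimizers:
        opt.zero_grad()
    # Chain the parts on one device; a real PP schedule would run each part
    # on its own stage and send activations between ranks.
    out = x
    for part in model_parts:
        out = part(out)
    loss = out.pow(2).mean()
    loss.backward()
    # Same action applied to every part's optimizer and scheduler.
    for opt in optimizers:
        opt.step()
    for sched in lr_schedulers:
        sched.step()
    return loss.detach()


loss = train_step(torch.randn(4, 16))
```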

The remaining train.py and optimizer/lr_scheduler changes add syntactic sugar to simplify calling a method on each model part or each optimizer.
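The syntactic sugar can be sketched as a small container that exposes the usual single-optimizer methods and fans each call out to every part. `OptimizersContainer` and `FakeOptimizer` are hypothetical names used only for illustration; the fake optimizer just counts calls so the sketch runs without torch.

```python
class OptimizersContainer:
    """Hypothetical wrapper: holds one optimizer per model part and exposes
    the same step()/zero_grad() calls a single optimizer would."""

    def __init__(self, optimizers):
        self.optimizers = list(optimizers)

    def step(self):
        for opt in self.optimizers:
            opt.step()

    def zero_grad(self):
        for opt in self.optimizers:
            opt.zero_grad()


class FakeOptimizer:
    """Stand-in that only counts calls."""

    def __init__(self):
        self.steps = 0
        self.zeroed = 0

    def step(self):
        self.steps += 1

    def zero_grad(self):
        self.zeroed += 1


opts = OptimizersContainer([FakeOptimizer(), FakeOptimizer()])
opts.zero_grad()
opts.step()
print([o.steps for o in opts.optimizers])  # [1, 1]
```

With this shape, the training loop can call `opts.step()` once regardless of how many model parts the rank holds.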

tianyu-l commented 3 months ago

Another question: currently, for PP, do we do data loading on all ranks? It seems we only need to do it on the first stage.

wconstab commented 3 months ago


Good catch; it would be good to avoid data loading on later PP stages. I might make the fix in a follow-up PR, if you don't mind, to keep things separate.

wconstab commented 3 months ago

@tianyu-l I've cleaned up the simpler suggestions and am still wondering what to do about the following.


For data loading, I opened #411; let's continue the discussion on that PR.

CheckpointManager and ModelWrapper/OptimizerWrapper have different styles: one takes a module or a list of modules, the other always takes a list. Shall we make them consistent?

LMK what you think about my question inline on this point. Shall I change the wrappers?
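One way to make the styles consistent, assuming the wrappers should accept both forms, is to normalize the argument to a list at construction time. `as_list` and this `ModelWrapper` are hypothetical sketches, not the existing classes; the alternative of always requiring a list would simply drop the normalization step.

```python
def as_list(obj_or_objs) -> list:
    """Return `obj_or_objs` as a list: collections are copied, single objects are wrapped."""
    # Only list and tuple count as collections here; any other object
    # (e.g. a single model) is treated as one item.
    if isinstance(obj_or_objs, (list, tuple)):
        return list(obj_or_objs)
    return [obj_or_objs]


class ModelWrapper:
    """Hypothetical wrapper that accepts one model or a list of model parts."""

    def __init__(self, model_or_parts):
        self.model_parts = as_list(model_or_parts)

    def __len__(self):
        return len(self.model_parts)


# Strings stand in for real modules.
print(len(ModelWrapper("model")))               # 1
print(len(ModelWrapper(["stage0", "stage4"])))  # 2
```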

Instead of completely modifying these building blocks (e.g. build_optimizer), I wonder if it makes sense to create a class (call it UtilsBuilder, Trainer, or something else) with methods build_optimizer, build_lr_scheduler, etc. Then PP just needs to create a subclass and is free to call the same methods via super.

Do you think making these into a class/subclass is much different? One thing I could do is move the optimizer builder into the same file as the lr_scheduler builder (I'd have to do that anyway if I made them into one class). Or I could define the 'build_optimizer' function outside of the 'build_optimizers' function instead of inside it.
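The class/subclass idea from the review can be sketched as follows, assuming dict stand-ins instead of real optimizers and schedulers. `Trainer` and `PipelineTrainer` are hypothetical; only the method names `build_optimizer` and `build_lr_scheduler` come from the discussion.

```python
class Trainer:
    """Hypothetical builder: each build step is a method a subclass can override."""

    def __init__(self, lr: float = 1e-3, warmup_steps: int = 10):
        self.lr = lr
        self.warmup_steps = warmup_steps

    def build_optimizer(self, model):
        # Stand-in for e.g. torch.optim.AdamW(model.parameters(), lr=self.lr).
        return {"model": model, "lr": self.lr}

    def build_lr_scheduler(self, optimizer):
        # Stand-in for a real LR scheduler wrapping `optimizer`.
        return {"optimizer": optimizer, "warmup_steps": self.warmup_steps}


class PipelineTrainer(Trainer):
    """PP variant: takes lists and reuses the base methods once per element."""

    def build_optimizer(self, model_parts):
        # One optimizer per model part.
        optimizers = []
        for part in model_parts:
            optimizers.append(super().build_optimizer(part))
        return optimizers

    def build_lr_scheduler(self, optimizers):
        # One LR scheduler per optimizer.
        schedulers = []
        for opt in optimizers:
            schedulers.append(super().build_lr_scheduler(opt))
        return schedulers


trainer = PipelineTrainer()
optimizers = trainer.build_optimizer(["stage0", "stage4"])
schedulers = trainer.build_lr_scheduler(optimizers)
print(len(optimizers), len(schedulers))  # 2 2
```

The module-level alternative amounts to the same loop without a subclass, e.g. `[build_optimizer(p) for p in model_parts]` with `build_optimizer` defined at file scope; the main difference is that the class gives PP a method to override.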