davmacario / MDI-LLM

Implementation of Model-Distributed Inference for Large Language Models, built on top of LitGPT
MIT License

Add support for model-parallel training #28

Open davmacario opened 7 months ago

davmacario commented 7 months ago

The main limitation of LLMs is their huge model size; moreover, during training, the VRAM/RAM needed to store the model plus the backpropagation state (gradients and optimizer states) is much higher than during inference. As a result, it is possible to perform inference with GPT-2 XL on a single Nvidia GTX 1080 Ti (11 GB VRAM), but not training.
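
A rough back-of-envelope estimate illustrates the gap (this assumes fp32 parameters and an Adam-style optimizer with two extra states per parameter; exact numbers depend on precision and optimizer choice):

```python
# Back-of-envelope memory estimate for GPT-2 XL (~1.5B parameters),
# assuming fp32 (4 bytes/param) and an Adam-style optimizer.
n_params = 1.5e9
bytes_per_param = 4

weights = n_params * bytes_per_param              # ~6 GB
grads = n_params * bytes_per_param                # ~6 GB
adam_states = 2 * n_params * bytes_per_param      # ~12 GB (first and second moments)

inference_gb = weights / 1e9                          # ~6 GB -> fits in 11 GB
training_gb = (weights + grads + adam_states) / 1e9   # ~24 GB, before activations

print(f"inference: ~{inference_gb:.0f} GB, training: ~{training_gb:.0f} GB")
```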

On a multi-GPU system, DistributedDataParallel does not solve the issue: each device must still hold the whole model, since only the data is parallelized.

To use model parallelism, a possible working approach is to create a new class (ModelParallelGPT) that inherits from the original model class (GPT) but assigns different pieces of the model to different devices on the host (see the sketch below). This trivial partitioning is far less efficient than MDI at the inference stage, since it does not allow pipelining (only one GPU is active at a time), but it is the only way to train the model.
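
A minimal sketch of what that could look like in plain PyTorch. Names such as ModelParallelGPT, blocks_a, and blocks_b are illustrative, and nn.TransformerEncoderLayer stands in for the actual LitGPT Block; this is not the repository's implementation.

```python
import torch
import torch.nn as nn


class ModelParallelGPT(nn.Module):
    """Toy GPT-like model split across two GPUs on the same host."""

    def __init__(self, vocab_size=50257, n_embd=768, n_head=12, n_layer=12):
        super().__init__()
        self.dev0 = torch.device("cuda:0")
        self.dev1 = torch.device("cuda:1")
        split = n_layer // 2

        # First half of the network (embedding + early blocks) on GPU 0.
        self.wte = nn.Embedding(vocab_size, n_embd).to(self.dev0)
        self.blocks_a = nn.ModuleList(
            [nn.TransformerEncoderLayer(n_embd, n_head, batch_first=True)
             for _ in range(split)]
        ).to(self.dev0)

        # Second half (late blocks + LM head) on GPU 1.
        self.blocks_b = nn.ModuleList(
            [nn.TransformerEncoderLayer(n_embd, n_head, batch_first=True)
             for _ in range(n_layer - split)]
        ).to(self.dev1)
        self.lm_head = nn.Linear(n_embd, vocab_size).to(self.dev1)

    def forward(self, idx):
        x = self.wte(idx.to(self.dev0))
        for blk in self.blocks_a:
            x = blk(x)
        # Activations (and, in the backward pass, gradients) cross the
        # device boundary here; only one GPU is busy at any given time.
        x = x.to(self.dev1)
        for blk in self.blocks_b:
            x = blk(x)
        return self.lm_head(x)
```

Training then proceeds as usual: autograd routes gradients back across the device boundary automatically, and a single optimizer over `model.parameters()` works because each parameter already lives on its assigned device. The loss must be computed on `cuda:1`, so targets need to be moved there.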

Another possible approach would be to check out PiPPy 👀.

davmacario commented 7 months ago

See model-parallel branch.