microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

How to use one's existing vertical MP with DeepSpeed #674

Open stas00 opened 3 years ago

stas00 commented 3 years ago

The docs mention in several places that DeepSpeed can be used with a model that implements its own MP. My guess is that this refers to vertical model slicing as taught in the PyTorch tutorial, that is, groups of layers spread across several GPUs.

OK, we have implemented vertical-slicing MP for t5, gpt2 and bart in transformers using a naive, slow approach, but I have no idea how to bolt DeepSpeed on top of such a model when it spans several GPUs. The typical problem is a conflict between the model doing its own device management (moving its layers, data and inputs between devices) and any external solution. For example, I can't run DP or DDP over 2 GPUs that are also used for vertical MP.

I do find it unclear, though, that in some places the documentation talks about supporting one's own MP without any explanation or example of how exactly that would work. I think it suggests looking at Megatron-LM for such an example, but that's a complex project. What would have helped a lot is a simple example of how to bolt DeepSpeed onto an MP-enabled model with just a few layers, such as the simple toy MP model from the same PyTorch tutorial.
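For reference, a toy naive-MP model of the kind mentioned above looks roughly like this. It is a sketch modeled on the PyTorch model-parallel tutorial, not a DeepSpeed example; the CPU fallback is an addition so that it runs without two GPUs. Because every copy of this model occupies both GPUs, those same GPUs cannot also serve as separate DP/DDP replicas, which is the conflict described above.

```python
import torch
import torch.nn as nn

# Use two GPUs if available; otherwise fall back to CPU so the sketch still runs.
if torch.cuda.device_count() >= 2:
    DEV0, DEV1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    DEV0 = DEV1 = torch.device("cpu")


class ToyMPModel(nn.Module):
    """Naive vertical MP: the first layer lives on DEV0, the second on DEV1."""

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10).to(DEV0)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5).to(DEV1)

    def forward(self, x):
        # The model moves activations between devices itself, so one copy
        # of it spans both GPUs, and only one GPU is busy at any moment.
        x = self.relu(self.net1(x.to(DEV0)))
        return self.net2(x.to(DEV1))


model = ToyMPModel()
out = model(torch.randn(20, 10))
loss = nn.MSELoss()(out, torch.randn(20, 5).to(DEV1))
loss.backward()  # autograd routes gradients back across the devices
```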

Thank you!

And I also understand that combining one's own vertical MP with DeepSpeed might actually turn out worse than just using ZeRO partitioning, and that the whole setup will probably be much more complicated, but it's hard to appreciate or compare options when there is a barrier to trying them.

My gut feeling, not yet backed by enough experience, tells me that DeepSpeed's solution to the memory issue completely eliminates the need for vertical MP (almost certainly once ZeRO stage 3 is implemented). But our effort to implement MP in transformers is somewhat incomplete, since naive MP doesn't scale, so I'm trying either to find a way to make it efficient or to remove it altogether if DeepSpeed is all one needs in every scenario.

I'm trying to figure out whether bolting just PP onto the naive MP will do the trick. It's tricky, though, since one needs to "sequentialize" the model's stack for PP to work.
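"Sequentializing" means expressing the model as an ordered list of layers that each take and return a single tensor, so the list can be cut into contiguous pipeline stages (DeepSpeed's `PipelineModule`, for instance, is built from such a list of layers). The sketch below assumes a toy model whose blocks already have that tensor-in, tensor-out shape; `Block`, `sequentialize` and `split_into_stages` are hypothetical names, and real transformers blocks would first need their extra inputs and outputs packed into tensors.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block: one tensor in, one out."""

    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return x + torch.relu(self.ff(x))


def sequentialize(embed, blocks, head):
    """Flatten embedding, block stack and head into one ordered layer list."""
    return [embed, *blocks, head]


def split_into_stages(layers, num_stages):
    """Cut the layer list into contiguous stages whose sizes differ by at most 1."""
    base, extra = divmod(len(layers), num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        size = base + (1 if s < extra else 0)
        stages.append(nn.Sequential(*layers[start:start + size]))
        start += size
    return stages


torch.manual_seed(0)
dim = 8
embed, head = nn.Linear(4, dim), nn.Linear(dim, 2)
blocks = nn.ModuleList(Block(dim) for _ in range(4))

layers = sequentialize(embed, blocks, head)  # 6 layers in total
stages = split_into_stages(layers, num_stages=2)

x = torch.randn(3, 4)
full = nn.Sequential(*layers)(x)
piped = x
for stage in stages:  # in a real pipeline each stage would sit on its own GPU
    piped = stage(piped)
```

Running the stages one after another reproduces the full model's output, which is the property a pipeline engine relies on when it places each stage on a different GPU.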

samyam commented 3 years ago

@stas00 We do support both horizontal MP and vertical MP (pipeline) with DeepSpeed in addition to ZeRO. In fact, you can use all three together. As you noticed, we have not released an example yet, but it's on our to-do list.

But for the problem you are facing, we do have a Pipeline Parallelism example that might help you get the t5, gpt2, and bart models you mentioned going:

https://github.com/microsoft/DeepSpeedExamples/tree/master/pipeline_parallelism

https://github.com/microsoft/DeepSpeedExamples/blob/400cd1bc3524507301f39c659d3069672c4ab464/pipeline_parallelism/train.py#L120

Following this example, you should be able to combine pipeline parallelism with data parallelism or ZeRO Stage 1 in DeepSpeed. This uses DeepSpeed's own pipeline parallelism, but if your model has been sequentialized, DeepSpeed can do the rest for you. @ShadenSmith knows the details of PP in DeepSpeed better than I do, so I'm tagging him here.

If you want to use your own implementation of pipeline parallelism or tensor slicing with DeepSpeed, you can do that too, but it requires you to implement a model parallel unit (mpu) object and pass it to DeepSpeed. This mpu object should implement a few methods that tell DeepSpeed which GPUs are involved in model parallelism and which are involved in data parallelism.

How to pass the mpu to DeepSpeed: https://github.com/microsoft/DeepSpeedExamples/blob/400cd1bc3524507301f39c659d3069672c4ab464/Megatron-LM/pretrain_gpt2.py#L172

What the mpu should implement:

mpu.get_model_parallel_group()
mpu.get_data_parallel_group()
mpu.get_model_parallel_rank()
mpu.get_data_parallel_rank()
mpu.get_data_parallel_world_size()
mpu.get_model_parallel_world_size()

https://github.com/microsoft/DeepSpeed/blob/82cecf69c357af5946f0548d2516789cfe2d6933/deepspeed/runtime/engine.py#L520
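A minimal sketch of an object with those six methods, assuming a Megatron-style layout in which each model-parallel group is a run of `mp_size` consecutive ranks and each data-parallel group takes every `mp_size`-th rank. `SimpleMPU` is a hypothetical name, and to keep the bookkeeping checkable without a distributed launch, the two `*_group()` methods return plain rank lists here; a real mpu must return `torch.distributed` process groups instead.

```python
class SimpleMPU:
    """Hypothetical mpu: `world_size` ranks split into model-parallel groups
    of `mp_size` consecutive ranks; data-parallel groups are strided."""

    def __init__(self, rank, world_size, mp_size):
        assert world_size % mp_size == 0, "world size must be divisible by mp_size"
        self.rank, self.world_size, self.mp_size = rank, world_size, mp_size
        self.dp_size = world_size // mp_size

    def get_model_parallel_rank(self):
        return self.rank % self.mp_size

    def get_data_parallel_rank(self):
        return self.rank // self.mp_size

    def get_model_parallel_world_size(self):
        return self.mp_size

    def get_data_parallel_world_size(self):
        return self.dp_size

    def get_model_parallel_group(self):
        # Ranks holding the other slices of this rank's model replica.
        # (A real mpu returns a torch.distributed process group here.)
        start = self.get_data_parallel_rank() * self.mp_size
        return list(range(start, start + self.mp_size))

    def get_data_parallel_group(self):
        # Ranks holding the same model slice in the other replicas.
        return list(range(self.get_model_parallel_rank(), self.world_size, self.mp_size))


# Example: rank 3 of 8 GPUs with 2-way model parallelism.
mpu = SimpleMPU(rank=3, world_size=8, mp_size=2)
```

In a real run, each group would be created with `torch.distributed.new_group(ranks)`, which every process must call for every group (including groups it is not a member of), and the object would then be handed over as `deepspeed.initialize(..., mpu=mpu)`, as in the linked Megatron-LM example.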

In terms of if we ever need Pipeline Parallelism or Model Parallelism, when we already have ZeRO, I think there are situations when Pipeline Parallelism can be faster. This is specifically the case when the network bandwidth is too low for ZeRO-powered or standard data parallel training. Combining pipeline parallel training with data parallel training can reduce network traffic significantly and help improve throughput when the network is the bottleneck. We actually did a comparative study of the pros and cons of these three techniques in our last press release:

https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/

under the subsection "Understanding tradeoffs between data, model and pipeline parallelism" in the 3D parallelism section.
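A back-of-envelope estimate shows why combining PP with DP can cut network traffic. It assumes a ring all-reduce (each GPU sends roughly 2(n-1)/n times the gradient size per step), fp16 values, and that a stage boundary carries one activation tensor forward and its gradient backward; it ignores overlap, micro-batching overheads and ZeRO's own communication pattern. The sizes are GPT-2-XL-like numbers picked purely for illustration, not measurements, and the function names are hypothetical.

```python
def allreduce_bytes_per_gpu(num_params, n_gpus, bytes_per_elem=2):
    """Approximate bytes each GPU sends in one ring all-reduce of the gradients."""
    return 2 * (n_gpus - 1) / n_gpus * num_params * bytes_per_elem


def pipeline_boundary_bytes(batch, seq_len, hidden, bytes_per_elem=2):
    """Approximate bytes crossing one stage boundary per step:
    activations sent forward plus their gradients sent backward."""
    return 2 * batch * seq_len * hidden * bytes_per_elem


params = 1.5e9  # illustrative model size

# Pure DP over 8 GPUs: every GPU all-reduces gradients for all parameters.
dp_only = allreduce_bytes_per_gpu(params, n_gpus=8)

# 4 pipeline stages x 2 DP replicas on the same 8 GPUs: each GPU only
# all-reduces its stage's quarter of the parameters, with one other replica.
pp_dp_allreduce = allreduce_bytes_per_gpu(params / 4, n_gpus=2)
pp_boundary = pipeline_boundary_bytes(batch=8, seq_len=1024, hidden=1600)
```

With these illustrative numbers, pure DP moves about 5.25 GB per GPU per step, whereas the PP+DP layout all-reduces about 0.75 GB per GPU plus roughly 52 MB per stage boundary.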

stas00 commented 3 years ago

Your comment is phenomenal; thank you so much for explicitly covering all the bases, @samyam. This is very helpful for understanding how things stack together.

Here is the current reality in transformers:

We have the naive implementation of vertical MP, as shown in the top diagram of the GPipe paper:

[Image "mp-pp": GPipe paper diagram contrasting naive model parallelism (top) with pipeline parallelism]

So we have everything a pipeline has, except there is no RPC and we manage all the data and layer switching ourselves. However, it's very inefficient because the GPUs sit idle, which is exactly what PP solves. I wasn't able to get this setup working with DeepSpeed. Any ideas on whether we could do anything here, or would it be pointless and we need to sort out PP instead?
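The idling can be quantified with the bubble estimate from the GPipe paper, which puts the bubble overhead at O((K-1)/(M+K-1)) for K stages and M micro-batches. The sketch below assumes every stage takes the same time per micro-batch and ignores communication; M = 1 corresponds to the naive MP described here, where only one GPU works at any moment. The function name is hypothetical.

```python
def pipeline_utilization(num_stages, num_micro_batches):
    """Fraction of time each GPU is busy in a GPipe-style schedule,
    assuming equal per-stage compute time and free communication."""
    k, m = num_stages, num_micro_batches
    # The pipeline takes m + k - 1 time slots, each GPU is busy in m of them.
    return m / (m + k - 1)


naive_4_gpus = pipeline_utilization(4, 1)  # naive MP: one batch, no overlap
piped_4_gpus = pipeline_utilization(4, 8)  # same 4 GPUs, 8 micro-batches
```

With 4 GPUs, naive MP keeps each GPU busy 25% of the time, while splitting the batch into 8 micro-batches raises that to 8/11, roughly 73%.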

Wrt pipelines, I'm working on porting transformers' t5 to PP, but so far it's been a difficult fight. The problem is that transformers models pass a ton of non-tensors to forward, and all 3 implementations (yours, fairscale's, and the upcoming one in pytorch-1.8) restrict inputs and outputs to a single tensor or a tuple of tensors and nothing else. That works for simple logic, but it really doesn't work for a complex model, which most transformers models are. You can see my attempts at remedying this here: https://github.com/pytorch/pytorch/pull/50693. I need to be able to pass None, (), tuples of tuples of tensors, and probably other complex structures. As a temporary workaround, I came up with a way to encode some of the above into tensors and decode them back inside forward, and then do the same in reverse for the outputs. That pytorch PR also proposes a solution for all 3 implementations that would support almost any data and make porting to PP much easier.
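A different route from encoding into tensors is to flatten the nested inputs before a stage boundary and rebuild them on the other side. This is sketched here under the assumption that the nesting structure is identical on every step, so each stage can keep the structure "spec" locally and only the flat list of tensors has to cross the boundary. `flatten` and `unflatten` are hypothetical helpers, not part of any of the three pipeline implementations, and plain ints stand in for tensors in the example.

```python
LEAF, NONE = "leaf", "none"  # markers used inside the structure spec
_END = object()              # sentinel for detecting leftover leaves


def flatten(obj):
    """Split a nested structure into (leaves, spec).

    Plain tuples and lists are walked recursively, None is recorded only in
    the spec, and anything else (e.g. a tensor) is treated as a leaf."""
    if obj is None:
        return [], NONE
    if type(obj) in (tuple, list):  # exact types only; namedtuples stay leaves
        leaves, child_specs = [], []
        for item in obj:
            item_leaves, item_spec = flatten(item)
            leaves.extend(item_leaves)
            child_specs.append(item_spec)
        return leaves, (type(obj), child_specs)
    return [obj], LEAF


def unflatten(leaves, spec):
    """Rebuild the original structure from the flat leaves and the spec."""
    it = iter(leaves)

    def build(s):
        if s == NONE:
            return None
        if s == LEAF:
            return next(it)
        container, child_specs = s
        return container(build(c) for c in child_specs)

    result = build(spec)
    if next(it, _END) is not _END:  # every leaf must be consumed exactly once
        raise ValueError("more leaves than the spec describes")
    return result


# None, an empty tuple and nested tuples, with ints standing in for tensors.
example = (1, None, (), ((2, 3), None), [4])
leaves, spec = flatten(example)
restored = unflatten(leaves, spec)
```

If the structure does change between steps (for example, a cache that is None on the first step and a tuple afterwards), the spec would itself have to be sent along, which brings back the encoding problem described above.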

So my goal is to find a way to make PP work with transformers; whichever of the 3 implementations can help me overcome the tensors-only limitation would be amazing! And then, as you're saying, I should be able to deploy DeepSpeed with it.

In terms of if we ever need Pipeline Parallelism or Model Parallelism, when we already have ZeRO, I think there are situations when Pipeline Parallelism can be faster.

That's super important - thank you for that insight. I'm looking forward to figuring out how to make these things work and then benchmark to see which is which.

Thanks again for your in-depth commentary, @samyam

stas00 commented 3 years ago

I have a follow up question:

This mpu object should implement a few methods that tell DeepSpeed which GPUs are involved in model parallelism and which are involved in data parallelism.

So basically, what you're saying is that such a setup requires at least 4 GPUs, since you can't use the same 2 GPUs for both MP and DP, correct? I think I remember seeing a diagram of this somewhere.

I don't see how 2D or 3D parallelism can be done on 2 GPUs. If I'm not mistaken, each added dimension doubles the number of GPUs, so 2D needs at least 4 GPUs and 3D at least 8.

With DP+PP, we hide the fact that GPUs 1 and 3 are used for PP: we tell DP that we only have GPUs 0 and 2, and it is oblivious that the work it sends to GPUs 0 and 2 is also handed off to GPUs 1 and 3.

And if we do 3D, I think the same story repeats: we have to hide from each extra dimension the GPUs used by the other dimensions.
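The GPU count follows from the grid: the world size is dp * pp * mp, so with at least 2 along each axis in use, 2D needs 4 GPUs and 3D needs 8. The sketch below enumerates which ranks form each DP, PP and MP group on such a grid. The rank ordering (MP varying fastest, DP slowest) is an assumption chosen so that the 2D case reproduces the example above; real frameworks choose their own ordering, and `rank_of`/`groups_3d` are hypothetical helpers.

```python
from itertools import product

AXES = ("dp", "pp", "mp")  # outermost to innermost in the assumed rank numbering


def rank_of(coords, sizes):
    """Map grid coordinates to a global rank, with MP varying fastest."""
    rank = 0
    for axis in AXES:
        rank = rank * sizes[axis] + coords[axis]
    return rank


def groups_3d(dp_size, pp_size, mp_size):
    """Return, for each axis, the rank groups that communicate along that axis."""
    sizes = {"dp": dp_size, "pp": pp_size, "mp": mp_size}
    groups = {}
    for axis in AXES:
        others = [a for a in AXES if a != axis]
        axis_groups = []
        # Fix the coordinates on the other two axes, then sweep this axis.
        for fixed in product(*(range(sizes[a]) for a in others)):
            coords = dict(zip(others, fixed))
            group = [rank_of({**coords, axis: i}, sizes) for i in range(sizes[axis])]
            axis_groups.append(group)
        groups[axis] = axis_groups
    return groups


two_d = groups_3d(dp_size=2, pp_size=2, mp_size=1)    # 4 GPUs: DP + PP
three_d = groups_3d(dp_size=2, pp_size=2, mp_size=2)  # 8 GPUs: DP + PP + MP
```

For the 4-GPU case this gives pipelines [0, 1] and [2, 3] and DP groups [0, 2] and [1, 3]: DP only "sees" GPUs 0 and 2 for the first stage, while GPUs 1 and 3 run the second stage behind them.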

sdtblck commented 3 years ago

I have a similar problem: when I pass an mpu along with a custom GPT-2 pipeline-parallel model (building on the Megatron examples to get full 3D parallelism) to deepspeed.initialize, I get the following error:

File "/root/anaconda3/lib/python3.8/site-packages/deepspeed/__init__.py", line 121, in initialize
    assert mpu is None, "mpu must be None with pipeline parallelism"

Looking at the DeepSpeed initialize code, it seems to grab the mpu from the model itself?

engine = PipelineEngine(args=args,
                        model=model,
                        optimizer=optimizer,
                        model_parameters=model_parameters,
                        training_data=training_data,
                        lr_scheduler=lr_scheduler,
                        mpu=model.mpu(),
                        dist_init_required=dist_init_required,
                        collate_fn=collate_fn,
                        config_params=config_params)

So should the mpu be passed to the PipelineModule when building the model? Can you provide any examples of how to properly achieve 3D parallelism?