huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[Tensor Parallelism] Megatron-LM to transformers #10321

Open stas00 opened 3 years ago

stas00 commented 3 years ago

🚀 Feature request

Splitting off the discussion that started here: https://github.com/huggingface/transformers/pull/10301#issuecomment-782917393 to track a potential future feature of transformers: Tensor Parallelism (Horizontal Model Parallelism). For the bigger context please see the Parallelism notes.

Let's start with an important clarification: MP can mean many different things.

  1. Vertical MP - slice the model vertically, so that one or more full layers are placed on each GPU. In this case Vertical MP is a simple version of PP (Pipeline Parallelism) with chunks=1.
  2. Horizontal MP - slice the layers horizontally, so that a slice of the full model is placed on each GPU. Example: Megatron-LM.

At the moment I think it's only Megatron-LM that implements Horizontal MP. @anton-l has ported that model to transformers, except for the Horizontal MP parts, since transformers doesn't yet have support for it. There is already naive Vertical MP in t5 and gpt2 thanks to @alexorona's work, I ported Bart too but it's unmerged, and there is an ongoing effort to figure out how to implement the Pipeline. All of these will have to co-operate with each other and also share common tools.

@anton-l started sharing what needs to be done to make that important feature available - and then down the road potentially make it available to other (all?) transformers models.

@anton-l, the floor is yours.

anton-l commented 3 years ago

@stas00 thanks for starting this thread!

I guess, in order for everyone to be on the same page, a brief explanation of horizontal parallelism is needed. This would be a good place for future reference and for introducing other contributors to the core concepts.

NOTE for everyone reading: If you find any of the explanations below confusing, you can read about Megatron-LM in much more detail in its original paper: https://arxiv.org/pdf/1909.08053.pdf

The core idea

The main thing that separates Megatron-style (horizontal) parallelism from vertical parallelism is the way it splits the model layers between GPUs without the need for idle time during training/inference (i.e. waiting while the previous GPUs complete their work on the previous layers of the model). This makes the whole process much more asynchronous, just like in MapReduce. Here's my rough sketch of how it looks:

[figure: model parallelism sketch]

Now the question is: how do we split the computation of those layers so that the parallelized model produces the same results as the unparallelized one?

Parallelized layers

Let's start with a simple building block of any transformer: a fully connected layer (nn.Linear) followed by a nonlinear activation (GeLU). Following the Megatron paper's notation, we can write it as Y = GeLU(XA), where X and Y are the input and output vectors, and A is the weight matrix.

If we look at the computation in matrix form, it's easy to see how the matrix multiplication can be split between multiple GPUs:

[figure: parallel GEMM]

Basically, if we split the weight matrix A column-wise across N GPUs and perform the matrix multiplications XA_1 through XA_N in parallel, we end up with N output shards Y_1, Y_2, ..., Y_N which can be fed into GeLU independently: Y_i = GeLU(XA_i).

Using this principle, we can update an MLP of arbitrary depth, without the need for any synchronization between GPUs until the very end, where we need to reconstruct the output vector from shards. The authors provide a helpful illustration for that:

[figure: the paper's parallel MLP diagram]
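To make the column-split idea concrete, here is a minimal single-process sketch that simulates the GPUs with column shards of A and checks the equivalence numerically (plain PyTorch, illustrative shapes only):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_gpus = 2                      # simulated tensor-parallel world size
X = torch.randn(4, 32)          # a batch of input vectors
A = torch.randn(32, 64)         # the full weight matrix of the linear layer

# Reference: the unsplit computation Y = GeLU(XA).
Y = F.gelu(X @ A)

# Split A column-wise into A_1 ... A_N, one shard per (simulated) GPU,
# and apply GeLU to each partial result independently - no communication needed.
A_shards = torch.chunk(A, n_gpus, dim=1)
Y_shards = [F.gelu(X @ A_i) for A_i in A_shards]

# Concatenating the shards (what an all_gather would do) reproduces the full output.
print(torch.allclose(torch.cat(Y_shards, dim=1), Y, atol=1e-6))  # True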

Quick note on self-attention

Parallelizing the multiheaded attention layers is even simpler, since they are already inherently parallel, due to having multiple independent heads!

[figure: the paper's parallel self-attention diagram]
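A small single-process sketch of the same idea for attention (illustrative sizes; the output projection, which Megatron-LM makes row-parallel and follows with an all-reduce, is omitted here): splitting the Q/K/V projection weights column-wise is equivalent to giving each rank a subset of the heads.

import torch

torch.manual_seed(0)
batch, seq, n_heads, d_head = 2, 4, 8, 16
d_model = n_heads * d_head

x = torch.randn(batch, seq, d_model)
# Full (unsplit) projection weights for Q, K and V.
wq, wk, wv = (torch.randn(d_model, d_model) for _ in range(3))

def attention(x, wq, wk, wv, n_heads):
    b, s, _ = x.shape
    d_h = wq.shape[1] // n_heads
    q = (x @ wq).view(b, s, n_heads, d_h).transpose(1, 2)
    k = (x @ wk).view(b, s, n_heads, d_h).transpose(1, 2)
    v = (x @ wv).view(b, s, n_heads, d_h).transpose(1, 2)
    scores = torch.softmax(q @ k.transpose(-2, -1) / d_h ** 0.5, dim=-1)
    return (scores @ v).transpose(1, 2).reshape(b, s, -1)

# Reference: all heads on one device.
full = attention(x, wq, wk, wv, n_heads)

# Two simulated tensor-parallel ranks, each holding half of the heads,
# i.e. half of the columns of wq/wk/wv.
outs = []
for rank in range(2):
    cols = slice(rank * d_model // 2, (rank + 1) * d_model // 2)
    outs.append(attention(x, wq[:, cols], wk[:, cols], wv[:, cols], n_heads // 2))

# Concatenating the per-rank head outputs reproduces the full result.
print(torch.allclose(torch.cat(outs, dim=-1), full, atol=1e-5))  # True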

Practical implementation

If you want to just dive right in, here are the basic building blocks implemented in Megatron-LM: the column- and row-parallel linear layers, the vocabulary-parallel embedding, and the scatter/gather mappings between model-parallel regions.

All of these rely on basic Scatter, Gather and Reduce ops to split and aggregate the activations flowing between the weight shards. Thanks to PyTorch Distributed, we can use torch.distributed.all_reduce and all_gather for that, without having to worry about low-level GPU synchronization. The scatter and gather layers just have to define appropriate forward and backward passes, like so:

import torch
import torch.distributed

# The helpers used below (get_tensor_model_parallel_world_size/rank/group and
# split_tensor_along_last_dim) are Megatron-LM's own mpu utilities for querying
# the tensor-parallel process group.

def _split(input_):
    # Keep only this rank's chunk of the tensor, split along the last dim.
    world_size = get_tensor_model_parallel_world_size()
    input_list = split_tensor_along_last_dim(input_, world_size)
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()
    return output

def _gather(input_):
    # Collect the chunks from all ranks and concatenate them along the last dim.
    world_size = get_tensor_model_parallel_world_size()
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
    output = torch.cat(tensor_list, dim=last_dim).contiguous()
    return output

class ScatterToModelParallelRegion(torch.autograd.Function):
    # Forward: split across ranks; backward: gather the gradients back.
    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)

class GatherFromModelParallelRegion(torch.autograd.Function):
    # Forward: gather from all ranks; backward: split the gradients back out.
    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

In a single transformer layer, there are 4 communication operations in total across the forward and backward passes:

[figure: communication operations in a model-parallel transformer layer]
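To show where these mappings sit inside a layer, here is a rough sketch of a column-parallel linear followed by GeLU. This is not Megatron-LM's actual ColumnParallelLinear (which has more options and proper initialization); it just assumes the helpers and autograd functions above and an already-initialized tensor-parallel process group.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ColumnParallelLinearSketch(nn.Module):
    # Each rank holds one column shard A_i of the full weight matrix A.
    def __init__(self, in_features, out_features):
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()
        assert out_features % world_size == 0
        self.weight = nn.Parameter(torch.empty(out_features // world_size, in_features))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        y_local = F.linear(x, self.weight)  # local GEMM: X @ A_i
        y_local = F.gelu(y_local)           # the nonlinearity needs no communication
        # Gather the column shards when the full Y is needed downstream;
        # a row-parallel second linear would skip this and all-reduce its output instead.
        return GatherFromModelParallelRegion.apply(y_local)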

Other things to consider

Parallelized embeddings and output logits

Since the weights of the input and output embeddings of BERT/GPT2 are tied, they require a coordinated modification. In the original implementation, the input embedding matrix is parallelized along the vocabulary dimension (column-wise), and the output embeddings' matrix multiplication is parallelized together with the cross-entropy loss to reduce the communication size (see the end of section 3 in the paper).
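As a rough sketch of the vocabulary-parallel idea (simplified, not the actual Megatron-LM VocabParallelEmbedding): each rank owns a slice of the vocabulary, masks out tokens that belong to other ranks, and an all-reduce sums the partial lookups. It assumes the same tensor-parallel helpers as above; a real implementation would route the all-reduce through an autograd Function like the ones shown earlier.

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F

class VocabParallelEmbeddingSketch(nn.Module):
    # Each rank owns a contiguous slice of the vocabulary rows.
    def __init__(self, vocab_size, hidden_size, rank, world_size):
        super().__init__()
        assert vocab_size % world_size == 0
        shard = vocab_size // world_size
        self.vocab_start, self.vocab_end = rank * shard, (rank + 1) * shard
        self.weight = nn.Parameter(torch.empty(shard, hidden_size))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, input_ids):
        # Tokens that live on other ranks are looked up as zeros locally...
        mask = (input_ids < self.vocab_start) | (input_ids >= self.vocab_end)
        local_ids = (input_ids - self.vocab_start).masked_fill(mask, 0)
        out = F.embedding(local_ids, self.weight)
        out = out.masked_fill(mask.unsqueeze(-1), 0.0)
        # ...and an all-reduce sums the shards so every rank gets the full embeddings.
        torch.distributed.all_reduce(out, group=get_tensor_model_parallel_group())
        return out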

Hybrid model and data parallelism

Combining horizontal parallelism with data parallelism requires grouping the GPUs in a specific way, as described in appendix B.1:

[figure: GPU grouping for hybrid model and data parallelism]
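For reference, a sketch of how those groups could be built with plain torch.distributed (sizes are illustrative; Megatron-LM wraps this in its own initialization helpers): consecutive ranks form a tensor-parallel group, and ranks holding the same shard across replicas form a data-parallel group.

import torch.distributed as dist

def build_parallel_groups(world_size, tensor_parallel_size):
    # e.g. world_size=8, tensor_parallel_size=2 gives 4 tensor-parallel groups
    # of 2 GPUs each and 2 data-parallel groups of 4 GPUs each.
    data_parallel_size = world_size // tensor_parallel_size
    tensor_groups, data_groups = [], []
    # Consecutive ranks share the shards of one model replica (tensor parallelism).
    for i in range(data_parallel_size):
        ranks = list(range(i * tensor_parallel_size, (i + 1) * tensor_parallel_size))
        tensor_groups.append(dist.new_group(ranks))
    # Ranks holding the same shard across replicas average gradients (data parallelism).
    for i in range(tensor_parallel_size):
        ranks = list(range(i, world_size, tensor_parallel_size))
        data_groups.append(dist.new_group(ranks))
    # Note: every process must call new_group() for every group, in the same order.
    return tensor_groups, data_groups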

anton-l commented 3 years ago

Phew! That felt like the start of a whole blog post :smile:

As for porting all of this, I would follow fairseq's example and copy Megatron-LM's parallel layers verbatim into an existing (but separate) implementation of BertModel or GPT2Model as a proof-of-concept and then work from there.

After the first semi-working prototype we could figure out how to implement the switching mechanism between a homogeneous model and a parallelized one, but it's too early to think about that, IMO. What do you think, @stas00 ?

stas00 commented 3 years ago

Amazing! Thank you for this awesome presentation, @anton-l! This could totally be a great blog post - I agree!

Let me study the information you shared and I will follow up then!

Until then I have a quick suggestion: do you have easy access to 2 GPUs? That would be enough to get a PoC working, and then we can find a larger cluster with more GPUs to experiment on and eventually port the 8 splits from fairseq.

I suppose it'd be easier to implement this for Megatron-LM, but the main use would be t5 and gpt2, where we have most of the huge models at the moment, so we could start there as well, if that works for you. This can also be worked on independently of your Megatron-LM PR.

anton-l commented 3 years ago

Regarding the setup: I can borrow a second gpu for the time being, that shouldn't be a problem :)

As for the models, I think GPT2 is a good candidate for our experiments, since its transformers implementation is already stable and has multiple smaller checkpoints for quick demos.

Also, I don't think we should be too concerned about porting the 8 original splits of fairseq's Megatron, since I've already concatenated them for the model's PR. If everything was done correctly, this potentially allows us to create an arbitrary split across 2^n devices, not just 8.
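Re-splitting the merged checkpoint would then just be a matter of chunking each parallel weight along the right dimension, e.g. (illustrative helper, not existing transformers/Megatron API):

import torch

def shard_parallel_weight(weight, world_size, rank, dim):
    # Rank i keeps the i-th chunk of `weight` along `dim` (the output-feature
    # dim for a column-parallel layer, the input-feature dim for a row-parallel one).
    return torch.chunk(weight, world_size, dim=dim)[rank].contiguous()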

stas00 commented 3 years ago

Sounds good on all accounts. GPT2 would be perfect, @anton-l!

I had the same thought about just splitting your merged model if needed.

Please let us know how we can support you in this endeavor.

Just so you're aware, I mentioned in the other thread the DeepSpeed version of the Megatron-LM port - perhaps theirs is newer; I haven't had a chance to study it yet: https://github.com/jeffra/DSE/tree/master/megatron-lm . You could diff the different versions against the baseline - that is, I assume it has been changed, but perhaps it hasn't. Have a look if you want to; if not, that's fine too. Any starting point will be good.

morganmcg1 commented 3 years ago

@anton-l Thanks for the great work on this, it's really nice to be able to load the pretrained model, so thanks for that too! Did you make any progress on fine-tuning across multiple GPUs? Would love to see if the results get any better with some fine-tuning...

stas00 commented 3 years ago

@anton-l, let's do it if you have the resources and interest. Let me know how I can be of help.

Now, having used Megatron-LM in the BigScience experiments, it's time to port it to transformers.

adit299 commented 1 year ago

@stas00 @anton-l Just curious, has Megatron-LM now been ported to transformers? Or is the proof of concept mentioned here still pending:

As for porting all of this, I would follow fairseq's example and copy Megatron-LM's parallel layers verbatim into an existing (but separate) implementation of BertModel or GPT2Model as a proof-of-concept and then work from there.

I would love to work on this issue, if there is anything I could do!

brynhayder commented 5 months ago

Thanks for the nice overview.

Having read the paper, I disagree with the following statement (emphasis mine)

Using this principle, we can update an MLP of arbitrary depth, without the need for any synchronization between GPUs until the very end

If you split the first layer's weights column-wise, then its outputs are split across columns, so you need to split the second layer's weights across rows, and then you need to reduce (sum) the partial outputs before applying the next non-linearity. This is explained in Section 3 of the paper.
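A tiny single-process check of why that synchronization point matters: the row-parallel second GEMM produces partial sums, and a nonlinearity of a sum is not the sum of nonlinearities, so the reduce has to happen before it is applied (illustrative shapes, plain PyTorch):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(4, 8)
A = torch.randn(8, 16)   # first layer's weights, split column-wise
B = torch.randn(16, 8)   # second layer's weights, split row-wise

# Reference: Z = GeLU(GeLU(X A) B) on a single device.
Z = F.gelu(F.gelu(X @ A) @ B)

# Two simulated ranks: Y_i = GeLU(X A_i), partial sums Z_i = Y_i B_i.
A1, A2 = torch.chunk(A, 2, dim=1)
B1, B2 = torch.chunk(B, 2, dim=0)
Z1, Z2 = F.gelu(X @ A1) @ B1, F.gelu(X @ A2) @ B2

# Correct: reduce (sum) the partials first, then apply the nonlinearity.
print(torch.allclose(F.gelu(Z1 + Z2), Z, atol=1e-5))          # True
# Wrong: applying the nonlinearity before the reduce gives a different result.
print(torch.allclose(F.gelu(Z1) + F.gelu(Z2), Z, atol=1e-5))  # False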