jasperzhong / cs-notes

CS knowledge system

learn Megatron-LM #23

Closed: jasperzhong closed this issue 3 years ago

jasperzhong commented 3 years ago

https://github.com/NVIDIA/Megatron-LM

No time to lose.

jasperzhong commented 3 years ago

megatron/arguments.py

  1. The parse_args function

First determine the tensor-model-parallel size (t), then the pipeline-model-parallel size (p), which gives the model-parallel size (t * p). Finally the data-parallel size is d = n // (t * p).

The parallel configuration in use is printed in the log. (log screenshot omitted)
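
A minimal sketch of this size bookkeeping, assuming n is the world size (the variable names here are mine, not the ones used in arguments.py):

    # Minimal sketch of the size bookkeeping in parse_args (variable names are mine).
    n = 16                      # world size
    t = 2                       # tensor-model-parallel size
    p = 4                       # pipeline-model-parallel size
    model_parallel_size = t * p
    assert n % model_parallel_size == 0, 'world size must be divisible by t * p'
    d = n // model_parallel_size      # data-parallel size
    print(t, p, d)                    # 2 4 2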

jasperzhong commented 3 years ago

megatron/training.py

The pretrain function: the entry point of most training scripts.

  1. Get the model, optimizer and lr_scheduler (the setup_model_and_optimizer function).

There are two kinds of data-parallel communication: torchDDP and localDDP. localDDP is communication without overlap. I don't quite see why it exists; the docs say localDDP is actually better in some cases. Hm.

Note: if pipeline parallelism is used, localDDP must be used.

fp16 training uses the class Float16OptimizerWithFloat16Params. It looks like a reimplementation of what apex already does? Though it appears to support bf16.

The loss uses lm_loss and sop_loss. For sop_loss see https://github.com/vycezhong/read-papers/issues/19

  2. Call the train function.

The main logic of train lives in train_step, which is the core piece. The flow is simply: forward + backward (overlapped with communication) + backward-embedding-all-reduce + update.

The comment for backward-embedding-all-reduce says:

# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).

In general the embedding layer shares weights with the final layer (the output projection), but under pipelined model parallelism these two layers are not on the same worker, so a copy of word_embedding is created on the last stage, and the two copies need an extra all-reduce (on their grads) to keep the parameters in sync. Makes sense.
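
A hedged sketch of that extra all-reduce, assuming the mpu helpers are named as in Megatron-LM (is_pipeline_first_stage, is_pipeline_last_stage, get_embedding_group) and that the model exposes word_embeddings_weight() as in the BERT code quoted further below; the real logic lives in train_step in training.py:

    import torch
    from megatron import mpu   # assumption: mpu is importable and already initialized

    def allreduce_word_embedding_grads(model):
        # Only the two stages that hold a copy of the word embedding participate.
        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
            weight = model.word_embeddings_weight()
            # Sum the grads of the two copies so both stages apply the same update.
            torch.distributed.all_reduce(weight.grad, group=mpu.get_embedding_group())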

jasperzhong commented 3 years ago

mpu: model parallel utility

mpu/mappings.py

The Megatron-LM paper describes two conjugate functions:

One copies in the forward pass and all-reduces in the backward pass.

class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)

The other all-reduces in the forward pass and copies in the backward pass.

class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

The reasoning behind these two functions is easy to see. Think of all-reduce as a sum, and of returning input_ directly as a copy. The result of a copy is used on both workers, so its gradient comes from two parts and must be summed; the gradient of a sum can simply be copied.
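
A single-process toy that shows the same fact with plain autograd: a tensor "copied" into two branches (conceptually, two tensor-parallel workers) receives the sum of the two branch gradients:

    import torch

    x = torch.ones(3, requires_grad=True)
    y1 = x * 2          # "copy" of x used on worker 1
    y2 = x * 3          # "copy" of x used on worker 2
    (y1.sum() + y2.sum()).backward()
    print(x.grad)       # tensor([5., 5., 5.]) -- branch grads are summed, i.e. an all-reduce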

There is also a function that all-gathers in the forward pass and splits (scatters) in the backward pass.

class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatinate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

mpu/layers.py

  1. ColumnParallelLinear

Semantics: Y = XA + b, with A partitioned along its columns, so the columns of Y are partitioned too. gather_output is optional: if True, an all-gather is performed so that every GPU has the complete Y; otherwise each GPU only gets its own partition.

Broadcast x Split(1) -> Split(1)

The code:

    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.

        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

Note:

  2. RowParallelLinear

Semantics: Y = XA + b, with A partitioned along its rows and X along its columns, so the result of XA needs an all-reduce.

Split(1) x Split(0) -> PartialSum

The code:

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
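
Putting the two layers together gives the parallel transformer MLP from the paper: a ColumnParallelLinear with gather_output=False feeds a RowParallelLinear with input_is_parallel=True, so only one all-reduce is needed in the forward pass. A shape-only sketch on one rank (sizes are made up, no real communication here):

    import torch
    import torch.nn.functional as F

    b, s, h, ffn, t = 4, 16, 1024, 4096, 2      # made-up sizes, t = 2 tensor-parallel ranks
    X   = torch.randn(b, s, h)                  # replicated input (Broadcast)
    A_i = torch.randn(h, ffn // t)              # this rank's column shard of A
    B_i = torch.randn(ffn // t, h)              # this rank's row shard of B
    Y_i = F.gelu(X @ A_i)                       # ColumnParallelLinear output: Split(1)
    Z_i = Y_i @ B_i                             # RowParallelLinear partial output: PartialSum
    # all_reduce(Z_i) over the t ranks would give the full MLP output.
    print(Y_i.shape, Z_i.shape)                 # (4, 16, 2048) (4, 16, 1024)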

Note:

  3. VocabParallelEmbedding

To balance memory, the embedding is also sharded instead of sitting on a single GPU. It is partitioned along the vocab dimension, so each worker only holds part of the embedding table; some input tokens have no local embedding, so the embedding output needs an all-reduce to become complete (a toy illustration follows the code below).

The code:

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
            # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output
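
A toy, single-process rendition of the masking logic above (made-up sizes; the final all-reduce is only described in a comment):

    import torch
    import torch.nn.functional as F

    # Pretend vocab = 8 is split over t = 2 ranks and this "rank 1" owns ids [4, 8).
    vocab_start, vocab_end = 4, 8
    weight = torch.randn(4, 16)                 # this rank's shard of the embedding table
    input_ = torch.tensor([1, 5, 7, 2])
    mask = (input_ < vocab_start) | (input_ >= vocab_end)
    masked_input = input_.clone() - vocab_start
    masked_input[mask] = 0                      # ids this rank doesn't own are mapped to 0
    out = F.embedding(masked_input, weight)
    out[mask, :] = 0.0                          # ...and their output rows are zeroed, so the
    # all-reduce across tensor-parallel ranks fills them in from the owning rank.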

Note:


mpu/cross_entropy.py

This one is a bit harder to follow and I didn't get all the details, but the overall idea is clear. It is mentioned in the Megatron-LM paper: all-reduce the loss (size = b x s) rather than the logits (size = b x s x v).
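
A quick back-of-the-envelope comparison with made-up sizes shows why this matters:

    b, s, v = 8, 1024, 50257        # made-up batch size, sequence length, vocab size
    print(b * s)                    # 8192 elements if the loss is all-reduced
    print(b * s * v)                # ~412M elements if the logits were all-reduced instead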


mpu/random.py

CheckpointFunction: the gradient checkpointing implementation. Note that the rng state at forward time must be saved. In the backward pass, first restore the forward-time rng state so the recomputed result matches the original forward, then restore the backward-time rng state, and finally run the backward.
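
A stripped-down sketch of that rng bookkeeping using only the CPU rng state (the real CheckpointFunction also tracks the CUDA rng states):

    import torch

    # forward: remember the rng state, run forward without saving activations
    fwd_cpu_rng_state = torch.get_rng_state()
    # ... forward pass ...

    # backward: recompute under the forward rng, then put the backward rng back
    bwd_cpu_rng_state = torch.get_rng_state()
    torch.set_rng_state(fwd_cpu_rng_state)   # so dropout etc. matches the original forward
    # ... recompute forward, run backward ...
    torch.set_rng_state(bwd_cpu_rng_state)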


mpu/data.py

broadcast_data: broadcasts input_data from rank 0 to all tensor-model-parallel ranks. It is called in the get_batch function of pretrain_bert.py.

This is different from data parallelism, where different ranks load different data; different tensor-model-parallel ranks must load the same data.

I'm not quite sure why this is implemented with a broadcast, though. A bit odd.
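
For reference, a minimal sketch of such a broadcast within one tensor-model-parallel group, assuming the mpu helpers are named get_tensor_model_parallel_src_rank and get_tensor_model_parallel_group:

    import torch
    from megatron import mpu   # assumption: mpu is importable and already initialized

    def broadcast_within_tensor_parallel_group(tensor):
        # Every rank in the group ends up with the source rank's copy of the batch.
        torch.distributed.broadcast(tensor,
                                    src=mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())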


mpu/initialize.py

Several pieces of global data are defined:

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

The most important function is initialize_model_parallel, which mainly initializes the global data above. The other functions are basically getters/setters for it.

A group's size equals the size of that parallelism dimension, and the number of such groups is #GPU / (dimension size). For example, with n = 16 and p = 4, a pipeline group has 4 GPUs and there are 4 pipeline groups in total, and so on.

The group-assignment logic is described in the paper at https://github.com/vycezhong/read-papers/issues/188. For example, for a parallel configuration (p, t, d) = (4, 2, 2) the assignment looks as follows. (figure omitted)

That paper does not give the concrete group-assignment algorithm, but this part of the implementation reveals it.

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank,
                          tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group

First, the ranks are divided into p pipeline stages: stage 1 covers ranks [0, n//p), stage 2 covers [n//p, 2n//p), ..., stage p covers [(p-1)n//p, n). Each pipeline stage has n//p GPUs.

Within each stage, ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) picks one rank out of every t among the stage's n//p GPUs to form a data-parallel group, so each data-parallel group has size n // p // t = d.

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

From this we can see that the ranks in a tensor-model-parallel group are always adjacent, e.g. (0, 1), (2, 3), and so on.

    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size,
                      num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
        else:
            embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group

The pipeline-model-parallel groups take one rank every n//p ranks (= num_pipeline_model_parallel_groups), e.g. [0, 4, 8, 12].
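
Putting the three loops together, a pure-Python sketch that reproduces the assignment for n = 16, (p, t, d) = (4, 2, 2) without torch.distributed:

    n, p, t = 16, 4, 2
    d = n // (p * t)
    num_pp_groups = n // p                      # ranks per pipeline stage

    data_groups = [list(range(i * num_pp_groups + j, (i + 1) * num_pp_groups, t))
                   for i in range(p) for j in range(t)]
    tensor_groups = [list(range(i * t, (i + 1) * t)) for i in range(n // t)]
    pipeline_groups = [list(range(i, n, num_pp_groups)) for i in range(num_pp_groups)]

    print(data_groups)       # [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], ...]
    print(tensor_groups)     # [[0, 1], [2, 3], [4, 5], ...]
    print(pipeline_groups)   # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]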

jasperzhong commented 3 years ago

pipeline parallelism implementation

p2p_communication.py

Pipeline parallelism needs inter-stage P2P communication. The main implementation is the _communicate function; the other APIs are all wrappers around it.

Its docstring is well written and explains things clearly: it implements bidirectional send/recv between stages.

def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False, tensor_shape=None,
                 override_scatter_gather_tensors_in_pipeline=False,
                 dtype_=None):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.
        tensor_shape: optional, use when the input sequence contains less
                      tokens than the default sequence length
        override_scatter_gather_tensors_in_pipeline: optional, this is used
                                                     when tensor_shape is
                                                     provided to overwide
                                                     scatter gather tensors
        dtype_: optional, this is used when tensor_shape is provied and what
                is the type of tensor_shape
    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """

Implementation:

  1. If a recv is needed, an empty tensor is temporarily created as the receive buffer and returned.
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=requires_grad,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
  2. torch.distributed.batch_isend_irecv is used to issue the asynchronous send/recv ops in a batch (internally it is really just a for loop...), and then wait() is called on each request for synchronization.
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
  3. There is one trick: the scatter-gather optimization mentioned in the paper. When tensor model parallelism is used, the stage's last-layer output is replicated across tensor-parallel ranks (because of the all-reduce), so it can be split into t chunks before sending, and the receiver gathers the chunks back into the full tensor.
    # Split tensor into smaller chunks if using scatter-gather optimization.
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
    ...
    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and \
            args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

_communicate supports bidirectional send/recv; the actual usage is:

| API | send_next | send_prev | recv_prev | recv_next |
| --- | --- | --- | --- | --- |
| recv_forward | | | ✓ | |
| recv_backward | | | | ✓ |
| send_forward | ✓ | | | |
| send_backward | | ✓ | | |
| send_forward_recv_backward | ✓ | | | ✓ |
| send_backward_recv_forward | | ✓ | ✓ | |

Also, the first and last pipeline stages need special-casing: e.g. recv_forward is a no-op on the first stage, and recv_backward is a no-op on the last stage.
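
For example, recv_forward looks roughly like the sketch below (signature simplified; timers and tensor shapes omitted):

    def recv_forward():
        # Sketch only: the first stage has no upstream stage to receive activations from.
        if mpu.is_pipeline_first_stage():
            return None
        input_tensor, _ = _communicate(tensor_send_next=None, tensor_send_prev=None,
                                       recv_prev=True, recv_next=False)
        return input_tensor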


schedules.py

get_forward_backward_func selects which pipeline schedule to use.

def get_forward_backward_func():
    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    return forward_backward_func

NVIDIA's SC '21 paper describes an interleaved pipeline schedule that further reduces the bubble size at the cost of some extra communication. These notes do not cover that schedule and focus on the PipeDream-Flush implementation instead, i.e. the forward_backward_pipelining_without_interleaving function above.

(schedule figure omitted)

A PipeDream-Flush iteration consists of three phases: warm-up, steady (1F1B), and cooldown.

BTW, the 1F1B schedule is memory-efficient, because it caps the number of in-flight microbatches at the pipeline depth (p) instead of the number of microbatches (m) as in e.g. GPipe. Typically m >> p in order to shrink the bubble time.

  1. setup

First determine each worker's number of warm-up microbatches: p - rank - 1 (p being the pipeline-parallel world size), i.e. it decreases with the rank. The last stage needs zero warm-up microbatches and enters the steady phase immediately. A worked example follows the snippet below.

    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
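
A worked example with made-up p = 4 stages and m = 8 microbatches:

    p, m = 4, 8
    for rank in range(p):
        warmup = p - rank - 1
        print(rank, warmup, m - warmup)   # rank, warm-up microbatches, 1F1B (steady) microbatches
    # rank 0: 3 warm-up / 5 steady ... rank 3 (last stage): 0 warm-up / 8 steady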

Each worker also keeps FIFO queues storing the activations received from upstream (input_tensor) and the activations sent downstream (output_tensor); these saved activations are used by the backward passes.

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
  2. warm-up phase

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers=timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers=timers)
    
        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

    Note:

    • The first stage has no upstream stage; recv_forward simply returns None.
    • The activations received from upstream (input_tensor) are the input of this stage.
    • Each worker appends the upstream activations (input_tensor) and the downstream activations (output_tensor) to the queues.
  3. steady phase. Logic: forward -> send forward & recv backward -> backward -> send backward & recv forward

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers=timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers=timers)

            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers=timers)

        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers=timers)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers=timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers=timers)

Note:

By the way, the figure in the paper is slightly off: some of the forward passes should be shifted to after the backward. I have filed an issue about it.

  4. cooldown phase

Symmetric to the warm-up phase, it also runs num_warmup_microbatches iterations, but only does backward passes.

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers=timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers=timers)

Note: this phase drains the outstanding backward passes, so it only needs to pop from the queues.

After the three phases above finish, both input_tensors and output_tensors should be empty.

Note that in this pipeline implementation a single worker's communication and computation do not overlap: send and recv block, and a sent message must be received by the neighboring stage before the next computation can start.


model/distributed.py

With pipeline parallelism, the data-parallel wrapper can only be LocalDDP (see training.py).

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

The implementation is the DistributedDataParallel class. Judging from the defaults in arguments.py, the use_contiguous_buffers flag is True by default, so a contiguous buffer, called main_grad, is allocated for communication, and a backward hook is registered to accumulate gradients into it. The relevant code:

class DistributedDataParallel(DistributedDataParallelBase):
    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):
         ...
        self._grad_buffers = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}
            ...
            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():
                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
            ...
            # Backward hook.
            # Accumalation function for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator functtion.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad.data is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook

    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks."""
        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for _, buffer_ in self._grad_buffers.items():
                buffer_.data /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    buffer_.data, group=mpu.get_data_parallel_group())
        ...

MemoryBuffer is just an in-memory buffer on the GPU created with torch.zeros.

So LocalDDP has no overlap between computation and communication. Actually pipeline parallelism could just use torchDDP with the gradient accumulation steps set to m: pipeline parallelism splits a minibatch into m microbatches, which is no different from gradient accumulation.
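
As a sketch of that claim (my code, not Megatron's): with a torchDDP-wrapped model one would simply suppress the all-reduce for the first m - 1 microbatches via no_sync(), which is exactly gradient accumulation. ddp_model, micro_batches and loss_fn are placeholders.

    from contextlib import nullcontext

    def train_step_with_torch_ddp(ddp_model, optimizer, micro_batches, loss_fn):
        optimizer.zero_grad()
        m = len(micro_batches)
        for i, batch in enumerate(micro_batches):
            # Skip the gradient all-reduce for all but the last microbatch.
            ctx = ddp_model.no_sync() if i < m - 1 else nullcontext()
            with ctx:
                loss = loss_fn(ddp_model, batch) / m
                loss.backward()       # the all-reduce only fires on the last microbatch
        optimizer.step()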

Also, I don't see why a contiguous buffer is needed. The comment only says:

has the potential to reduce memory fragmentation.

I still don't get it...


model/transformer.py model/language_model.py model/bert_model.py

The last piece of pipeline parallelism: model partitioning. There isn't much to it, because today's large models are repetitive, so the model is simply split by layer count and every layer is an identical transformer layer. See model/transformer.py.

        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

The first and last stages are a bit special, though.

  1. The first stage needs to include the embedding; this is called pre_process. See model/language_model.py.
        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'
  2. The last stage needs to include the loss computation; this is called post_process. See model/bert_model.py.
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

training.py

Finally, back to the train_step function, to see the flow of a single training step:

  1. zero grad
    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()
  2. forward & backward
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

At this point the whole pipeline has finished both forward and backward: the loss and the gradients are ready.

  3. all-reduce
    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

Only now does the data-parallel all-reduce begin.

  4. embedding all-reduce

Because of the embedding weight sharing, both the first and last stages hold a copy of the embedding; to keep the parameters consistent, their grads are all-reduced. The other stages skip this step.

  5. update

Call optimizer.step() to update the parameters.

Note: different pipeline stages do not update their parameters in lockstep; they can finish at different times. There is actually no explicit pipeline flush; the synchronization is deferred to the recv wait at the start of the next iteration (so the first stage is presumably the bottleneck).

Finally, a breakdown of where the time goes. Configuration: (p, t, d) = (2, 2, 2), gpu_per_node = 2, m = 2.

Note that the worker printing the log is the last rank, which is also the last pipeline stage. (timing log screenshot omitted)

From the log, recv_forward seems to take a long time, but that is really the warm-up phase, since this is the last pipeline stage's log.

backward-embedding-all-reduce is also high, and is also mostly wait time: waiting for the first stage to finish its backward and the data-parallel all-reduce.

Only backward-send-forward-recv and backward-send reflect the real communication time. It is very short, much shorter than the data-parallel all-reduce!

The optimizer time is a bit surprising: 70% of it goes to handling mixed precision. Ha.

Finally, backward is not that much slower than forward; forward:backward = 1:2 is an exaggeration, here it is roughly forward:backward = 7:10.