pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[RFC] Pipeline parallelism for Pytorch/XLA #6347

Open YangFei1990 opened 10 months ago

YangFei1990 commented 10 months ago

🚀 Description

Pipeline parallelism (PP) is a technique used in deep learning training to improve efficiency and reduce the training time of large neural networks. Here we propose a pipeline parallelism solution for PyTorch/XLA with the following goals:

Solution

NeuronxDistributed (a PyTorch/XLA-based library) offers an FX-partition-based pipeline parallelism solution that satisfies both objectives mentioned above. The solution traces the entire model into a static FX graph and partitions the graph into a series of submodules based on user annotations. FX's GraphModule is an nn.Module generated from an fx.Graph; it has a graph attribute, as well as code and forward attributes generated from that graph. The output of the partition is a list of graph modules, each of which belongs to one PP rank. These graph modules can be executed in the same way as native torch.nn.Modules and traced by XLA. At runtime, each PP rank runs only the graph module that belongs to it, following the predefined PP schedule. This execution is traced by XLA, compiled into executable programs, and then executed on the XLA device. In general, there is no fundamental difference from how models are trained on XLA devices today.
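The partition step can be pictured with a small pure-Python sketch. It does not use torch.fx: the ordered list of node names is a hypothetical stand-in for the top-level module calls in a traced FX graph, and `partition_at_layer_boundaries` is an illustrative helper, not a NeuronxDistributed API. It cuts only between transformer layers, splits the layers as evenly as possible, and leaves pre-layer nodes (e.g. the embedding) on the first stage and post-layer nodes (e.g. the final norm and lm_head) on the last stage. On an actual GraphModule, PyTorch's `torch.fx.passes.split_module` can perform the corresponding rewrite.

```python
from typing import Callable, List


def partition_at_layer_boundaries(
    node_names: List[str],
    is_layer: Callable[[str], bool],
    pp_size: int,
) -> List[List[str]]:
    """Split an ordered list of graph nodes into pp_size contiguous stages.

    Cuts are placed only at the start of a transformer layer, so no layer is
    ever split across two stages.
    """
    layer_positions = [i for i, name in enumerate(node_names) if is_layer(name)]
    if pp_size < 1 or len(layer_positions) < pp_size:
        raise ValueError("need pp_size >= 1 and at least one transformer layer per PP stage")

    # Near-even split: the first `extra` stages get one additional layer.
    base, extra = divmod(len(layer_positions), pp_size)
    cut_points = []
    layers_assigned = 0
    for stage in range(pp_size - 1):
        layers_assigned += base + (1 if stage < extra else 0)
        # The next stage begins at the first layer it owns.
        cut_points.append(layer_positions[layers_assigned])

    bounds = [0] + cut_points + [len(node_names)]
    return [node_names[bounds[k]:bounds[k + 1]] for k in range(pp_size)]


nodes = ["embed_tokens"] + [f"layers.{i}" for i in range(4)] + ["norm", "lm_head"]
stages = partition_at_layer_boundaries(nodes, lambda n: n.startswith("layers."), pp_size=2)
for rank, stage in enumerate(stages):
    print(rank, stage)
# 0 ['embed_tokens', 'layers.0', 'layers.1']
# 1 ['layers.2', 'layers.3', 'norm', 'lm_head']
```

Each inner list plays the role of one per-rank graph module, so the first and last stages naturally differ from the middle ones.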

Interacting with other parallelism techniques

Tensor parallelism

Since tensor parallelism (TP) operates within a module, and the current PP solution cuts at layer boundaries while keeping each module untouched, the PP solution works seamlessly with TP. The same solution is implemented in the NeuronxDistributed library, which uses Megatron-style TP APIs and FX-based PP.
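How TP and PP ranks might be laid out can be sketched as below. This assumes a Megatron-style layout (TP groups made of consecutive ranks, PP groups strided by `world_size // pp_size`); whether NeuronxDistributed uses exactly this ordering is an assumption, and `build_groups` is a hypothetical helper.

```python
from typing import List, Tuple


def build_groups(world_size: int, tp_size: int, pp_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (tp_groups, pp_groups) for a Megatron-style rank layout."""
    if world_size % (tp_size * pp_size) != 0:
        raise ValueError("world_size must be divisible by tp_size * pp_size")

    # TP groups: blocks of consecutive ranks, so TP collectives stay on nearby devices.
    tp_groups = [list(range(start, start + tp_size)) for start in range(0, world_size, tp_size)]

    # PP groups: ranks strided by the number of ranks per stage; each group
    # holds exactly one rank from every pipeline stage.
    ranks_per_stage = world_size // pp_size
    pp_groups = [list(range(first, world_size, ranks_per_stage)) for first in range(ranks_per_stage)]
    return tp_groups, pp_groups


tp_groups, pp_groups = build_groups(world_size=8, tp_size=2, pp_size=2)
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

With this layout a rank's TP peers all sit in the same pipeline stage, so TP collectives stay inside one stage's graph module while PP communication only crosses stage boundaries.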

ZeRO/FSDP

We support the ZeRO-1 optimizer, i.e. optimizer state sharding. ZeRO-2/3 and FSDP are not supported yet.
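As a rough illustration of what ZeRO-1 shards, the sketch below splits a flattened parameter vector of `total_numel` elements into contiguous ranges, one per data-parallel rank; each rank would keep optimizer state (e.g. Adam moments) only for its own range, update that slice, and then all-gather the updated parameters. The ceil-division layout and the `zero1_shard_ranges` helper are hypothetical, not NeuronxDistributed's implementation.

```python
from typing import List, Tuple


def zero1_shard_ranges(total_numel: int, dp_size: int) -> List[Tuple[int, int]]:
    """Return the [start, end) range of the flat parameter vector owned by each DP rank."""
    shard = -(-total_numel // dp_size)  # ceil division; trailing shards may be short or empty
    ranges = []
    for rank in range(dp_size):
        start = min(rank * shard, total_numel)
        end = min(start + shard, total_numel)
        ranges.append((start, end))
    return ranges


print(zero1_shard_ranges(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

With fp32 Adam, the two moment buffers cost 8 bytes per parameter, so sharding them over `dp_size` ranks reduces that cost to roughly `8 * total_numel / dp_size` bytes per rank.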

GSPMD

GSPMD works by marking tensors as sharded; the collective communication (CC) ops are then inserted during XLA compiler passes. Since our solution performs the PP partition purely at the PyTorch framework layer, sharded tensors should remain unchanged after the PP partition. The PP send/recv ops will be in a different graph from the model execution graph, so GSPMD can insert the corresponding CC ops into the model execution graph. Based on the above analysis, we expect only a small amount of effort to be required to support GSPMD with our solution. For example, we might need to add support for the case where a tensor marked as sharded is communicated between PP stages.

APIs

import transformers
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from neuronx_distributed import NxDPPModel, initialize_model_parallel

# tensor_parallel_size, pipeline_parallel_size, config, args, input_ids,
# attention_mask and labels are assumed to be defined by the training script.

# Initialize model parallelism
initialize_model_parallel(tensor_parallel_size, pipeline_parallel_size)

# User-specified torch model
model = transformers.LlamaForCausalLM(config)

# Apply the model wrapper, which traces and partitions the model based on the PP size
model = NxDPPModel(
    model,  # native PyTorch model
    transformer_layer_cls=LlamaDecoderLayer,  # transformer layer class used as the pipeline cutting boundary
    num_microbatches=args.num_microbatches,  # number of microbatches for the pipeline
    output_loss_value_spec=(True, False),  # data structure used to get the loss from the model output
    # ...
)
...
# NxDPPModel.run_train uses the pipeline engine to execute a batch in a pipelined manner
loss = model.run_train(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=labels,
)
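`run_train` takes a full batch, and `num_microbatches` controls how it is cut up before it enters the pipeline. The helpers below are a hypothetical, framework-free sketch of that step, not NxD's code: they split a batch along its first dimension into equal microbatches and average per-microbatch mean losses, which equals the full-batch mean loss when all microbatches have the same size.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_microbatches(batch: Sequence[T], num_microbatches: int) -> List[List[T]]:
    """Split a batch into equally sized microbatches along its first dimension."""
    if num_microbatches < 1 or len(batch) % num_microbatches != 0:
        raise ValueError("batch size must be divisible by num_microbatches")
    size = len(batch) // num_microbatches
    return [list(batch[i * size:(i + 1) * size]) for i in range(num_microbatches)]


def mean_loss(microbatch_losses: List[float]) -> float:
    """Average per-microbatch mean losses (valid when microbatches are equal-sized)."""
    return sum(microbatch_losses) / len(microbatch_losses)


per_sample_loss = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
mbs = split_microbatches(per_sample_loss, num_microbatches=4)
print(mbs)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(mean_loss([sum(mb) / len(mb) for mb in mbs]))  # 4.5
```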

Use cases

Supporting pipeline parallelism for a generic model is challenging. It becomes a much simpler problem if we focus on Transformer-based models that contain a linear list of modules (ModuleList/Sequential), based on the observation that most PP workloads are Transformer-based large language models (LLMs). This assumption not only makes the pipeline partition and runtime easier to implement, but also reduces some of the tracing overhead. While we focus on Transformer-based models, the design principle itself is generic and could be extended to any model architecture.

Current/Future work

Proposal

We would like to understand the PyTorch team's plan and roadmap for PP support and seek collaboration/integration. Although NeuronxDistributed's PP solution is aimed at XLA devices, most of its components are hardware agnostic, and the components that require special XLA treatment, mainly the communication part, are easy to decouple. It would be beneficial to have a first-party PyTorch PP solution that works on both GPU and XLA devices. We have observed that PyTorch has been developing a first-party solution for pipeline parallelism, namely PiPPy. PiPPy embraces the same idea as our solution, i.e. tracing the model and doing the PP partition on the FX graph. PiPPy's runtime was previously based on PyTorch RPC, which made it hard to run on XLA devices. However, we recently found that PiPPy has been refactored to remove the RPC backend and switch to torch c10d. This refactor makes it much easier to integrate with our solution.

Additional context

cc @JackCaoG @aws-rhsoln @amithrm. Please also tag any other relevant folks.

JackCaoG commented 10 months ago

FYI @wconstab @kwen2501

kwen2501 commented 10 months ago

Currently we use FX's symbolic tracer to do tracing. We have plan to move tracing into torch Dynamo for better model coverage and potentially better performance.

PiPPy migrated to PT2 tracer too :) See https://github.com/pytorch/PiPPy/pull/873

YangFei1990 commented 10 months ago

Yeah, we would like to do the same once the XLA CC ops are traceable by Dynamo.

JackCaoG commented 10 months ago

@alanwaketan I think there is no real blocker for our CC ops to be traceable, even with groups? I remember we didn't implement it because we didn't need it at that time.

alanwaketan commented 10 months ago

Right, there are no blockers. There was no immediate need on our end back then, and our distributed story has since shifted to focus on SPMD. So, feel free to contribute to it, @YangFei1990! If you need anything, I can share the traceable CC design doc with you.

yeounoh commented 10 months ago

Thanks @YangFei1990 !

yitongh commented 10 months ago

Hello, @YangFei1990. We have also implemented pipeline parallelism on GPU based on XLA. In essence, we trace the user's model using torch.fx symbolic tracing and split it into PP stages based on the split points provided by the user. Then we run the split FX graph and, like other torch modules, lower it to XLA through the LTC (Lazy Tensor Core) mechanism. We implemented the PipeDream algorithm for PP based on DeepSpeed's PP implementation. We did not use PiPPy because PiPPy relies on RPC for scheduling, which is heavy. Currently, in our PP implementation, we encounter the following issues:

  1. We need to implement send/recv XLA operations. Currently, the support for send/recv operators in PyTorch/XLA at the Python level is not yet complete. On the GPU side, we need to pass the destination rank for send and the source rank for recv. To use them together with other communication operators, we need to use tokens. However, XLA does not yet support (non-numeric) tokens when implementing operators like all-reduce; it only supports numeric tokens. The tokens for send/recv are not numeric, which means we either need to make operators like all-reduce support non-numeric tokens, or make send/recv support numeric tokens.

  2. During pipeline-parallel execution, we run many microbatches, for instance 64. A higher number of microbatches helps mask the bubble inherent to PP. In large language model scenarios, this results in an excessively large graph if we trace the entire PP execution at once, leading to very long compilation times. However, if we compile after running a few microbatches of fwd and bwd, then, due to the PipeDream algorithm, some of the activations from the fwd passes will be used as inputs to subsequent bwd passes. These activations then remain unreleased during the execution of bwd, increasing GPU memory usage. This is also why we need to implement send/recv in XLA: without it, we would have to compile and execute each fwd and bwd pass of every microbatch separately, leading to increased GPU memory usage.
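Why tokens matter can be shown with a toy scheduler. Without a data dependency, a compiler may reorder independent collectives, and if two ranks end up issuing them in different orders, they can deadlock. Threading a token through the ops (each op consumes the token produced by the previous one) adds explicit edges that pin the program order. The `schedule` function below is a hypothetical stand-in for an XLA compiler pass, not its real algorithm: it runs Kahn's topological sort and, among ready ops, picks the alphabetically first one to mimic an arbitrary but legal choice.

```python
from typing import Dict, List, Sequence, Tuple

Edge = Tuple[str, str]  # (must_run_first, must_run_after)


def schedule(ops: Sequence[str], deps: Sequence[Edge]) -> List[str]:
    """Topologically order ops; ties are broken alphabetically (an arbitrary compiler choice)."""
    indegree: Dict[str, int] = {op: 0 for op in ops}
    successors: Dict[str, List[str]] = {op: [] for op in ops}
    for before, after in deps:
        successors[before].append(after)
        indegree[after] += 1

    ready = sorted(op for op in ops if indegree[op] == 0)
    order: List[str] = []
    while ready:
        op = ready.pop(0)
        order.append(op)
        for nxt in successors[op]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)
        ready.sort()
    return order


def thread_token(ops_in_program_order: Sequence[str]) -> List[Edge]:
    """Each comm op consumes the token produced by the previous comm op."""
    return list(zip(ops_in_program_order, ops_in_program_order[1:]))


program_order = ["send_activation", "all_reduce_grads"]
print(schedule(program_order, deps=[]))  # ['all_reduce_grads', 'send_activation']
print(schedule(program_order, deps=thread_token(program_order)))  # ['send_activation', 'all_reduce_grads']
```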

Have you experienced any of the issues mentioned above? Very happy to have the opportunity to discuss them together.

yitongh commented 10 months ago

I have opened an issue regarding the token issue of send/recv and other communication operators in XLA. Welcome to join the discussion.

YangFei1990 commented 10 months ago

Hi @yitongh

  1. We implement send/recv with all-reduce, so all CC ops in our execution use float (numeric) tokens. You can refer to our send/recv ops: https://github.com/aws-neuron/neuronx-distributed/blob/main/src/neuronx_distributed/pipeline/comm.py#L38
  2. I do not understand why your graph size scales with the number of microbatches. Each microbatch should have the same graph on the same PP rank. By the PipeDream algorithm, I assume you mean the 1F1B schedule? The 1F1B schedule naturally introduces some imbalance in activation memory usage between stages; I do not think there is a way to avoid that.
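One way to express a point-to-point transfer with all-reduce, sketched below in plain Python, is a sum all-reduce over a two-rank group in which the sender contributes its tensor and the receiver contributes zeros, so the reduced result equals the sender's tensor. This is an assumption about the general technique; the linked comm.py may differ in detail, and `all_reduce_sum` / `send_via_all_reduce` are hypothetical simulations of collectives, not real APIs.

```python
from typing import Dict, List

Tensor = List[float]


def all_reduce_sum(contributions: Dict[int, Tensor]) -> Dict[int, Tensor]:
    """Simulate a sum all-reduce: every rank in the group receives the elementwise sum."""
    reduced = [sum(values) for values in zip(*contributions.values())]
    return {rank: list(reduced) for rank in contributions}


def send_via_all_reduce(local: Dict[int, Tensor], src: int, dst: int) -> Tensor:
    """Move local[src] to dst using only an all-reduce over the group {src, dst}."""
    numel = len(local[src])
    contributions = {
        src: local[src],     # sender contributes the real payload
        dst: [0.0] * numel,  # receiver contributes zeros (its own buffer is ignored)
    }
    return all_reduce_sum(contributions)[dst]


buffers = {0: [1.5, -2.0, 3.0], 1: [9.0, 9.0, 9.0]}
print(send_via_all_reduce(buffers, src=0, dst=1))  # [1.5, -2.0, 3.0]
```

Because every collective is then an all-reduce, all CC ops can share the same numeric-token mechanism; the price is that an all-reduce moves more data than a one-way send.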
JackCaoG commented 10 months ago

Question for both @yitongh and @YangFei1990, is each pipeline stage HLO identical in your use case?

YangFei1990 commented 10 months ago

@JackCaoG Each pipeline stage has a different graph. It is naturally MPMD, e.g. the embedding is only on the first PP rank and lm_head only on the last one.

yitongh commented 10 months ago

@JackCaoG Just like @YangFei1990's case, the HLO of each pipeline stage is different.

yitongh commented 10 months ago

@YangFei1990 Thanks for your reply.

  1. When using all-reduce for communication, how does your communication performance compare to using send/recv? With send/recv, an async send adds essentially no cost to the current stage, and recv is also a cheaper operator than all-reduce. However, I believe that in the PP scenario, even if all-reduce is used for PP communication, the performance impact is not significant, and the main bottleneck is still the bubble.

  2. I'm curious about the PP algorithm you're using and how many microbatches you run PP with. Yes, we use the 1F1B scheduling strategy from PipeDream. Under this strategy, we can use a relatively large number of microbatches so that the PP bubble becomes negligible, e.g. 128 microbatches with a PP size of 4. In this scenario, we need to consider where to place mark_step. If we add a mark_step only after all microbatches have been executed, we end up with a very large graph. If we add a mark_step after every few 1F1B cycles, then, because the fwd and bwd in a 1F1B step do not come from the same microbatch, some temporary activations from fwd are used by a subsequent bwd; these activations become inputs to the bwd graph and cannot be released while bwd executes. Consider a simple scenario where we add a mark_step after fwd and another after bwd. In this case, all intermediate activations from fwd that are required by bwd become inputs to the bwd graph, and these inputs cannot be released during the execution of bwd. Since bwd itself consumes GPU memory, this leads to higher-than-expected memory usage.
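The non-interleaved 1F1B schedule discussed here can be generated as follows. This follows the commonly used Megatron-style formulation (each stage does `num_stages - rank - 1` warmup forwards, then alternates one forward and one backward, then drains the remaining backwards); it is a sketch, not the DeepSpeed or NeuronxDistributed scheduler. If every pass is its own graph, each `F`/`B` entry is where a `mark_step` would go, and `peak_in_flight` counts how many microbatches' activations a stage holds at once.

```python
from typing import List, Tuple

Step = Tuple[str, int]  # ("F" or "B", microbatch index)


def one_f_one_b(num_stages: int, rank: int, num_microbatches: int) -> List[Step]:
    """Non-interleaved 1F1B schedule for one pipeline stage."""
    warmup = min(num_stages - rank - 1, num_microbatches)
    steady = num_microbatches - warmup
    steps: List[Step] = [("F", mb) for mb in range(warmup)]
    for i in range(steady):
        steps.append(("F", warmup + i))  # forward of a newer microbatch ...
        steps.append(("B", i))           # ... paired with the backward of an older one
    steps += [("B", mb) for mb in range(steady, num_microbatches)]  # cooldown
    return steps


def peak_in_flight(steps: List[Step]) -> int:
    """Max number of microbatches whose forward has run but whose backward has not."""
    live = peak = 0
    for kind, _ in steps:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


sched = one_f_one_b(num_stages=4, rank=0, num_microbatches=8)
print(" ".join(f"{k}{mb}" for k, mb in sched))
# F0 F1 F2 F3 B0 F4 B1 F5 B2 F6 B3 F7 B4 B5 B6 B7
print(peak_in_flight(sched), peak_in_flight(one_f_one_b(4, 3, 8)))  # 4 1
```

The pairing of `F(warmup + i)` with `B(i)` is why, when fwd and bwd live in separate graphs, the activations of an older microbatch enter the backward graph as inputs.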

YangFei1990 commented 10 months ago

@yitongh

  1. We have not tested send/recv, so I do not know how large the performance difference is. But as you said, it is not the major concern.
  2. In our implementation, each forward/backward execution is a separate graph, though we are considering fusing them for the interleaved pipeline schedule. However, each microbatch will still be a separate graph. If you fuse multiple microbatches into the same graph, can the compiler handle the CC dependencies properly? Also, for activation memory, it does not matter whether you fuse the microbatches together: the inputs/outputs of a given microbatch's forward can only be released after the corresponding backward is called, so even within the same graph this memory is preserved. As I mentioned previously, this imbalance in activation memory usage is inherent to pipelining and cannot be avoided.
yitongh commented 10 months ago

@YangFei1990 The compiler is able to correctly handle dependencies between communication operators due to the presence of tokens. As for the memory issue, if we separate the fwd graph and the bwd graph, we will use more GPU memory than if fwd and bwd were in the same graph. This is because when fwd and bwd are in different graphs, the activations from fwd cannot be released during bwd. These activations are not just the outputs of fwd but also include some intermediate activations required by bwd, such as the outputs of gemm. This issue is related to XLA and not a problem with pipelining.

yitongh commented 10 months ago

@YangFei1990 Allow me to provide a more precise explanation of the issue. During normal execution, intermediate activations from fwd are gradually released during the bwd process, and the GPU memory of these activations can be reused by operators in bwd. However, when we separate the fwd and bwd graphs, this reuse is broken because these activations become inputs to the bwd graph. Under the current mechanism of XLA, these inputs cannot be released or reused at runtime.
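The memory effect described here can be reproduced with a toy liveness model. It is an illustrative assumption, not a measurement of XLA: each backward step allocates `bwd_alloc` units that stay live until the backward finishes (e.g. gradients). When activations are ordinary intermediates (`pinned=False`), the activation consumed by a step is freed right away and its memory can be reused; when they are inputs to a separate backward graph (`pinned=True`), they stay live until that graph ends.

```python
from typing import Sequence


def backward_peak(act_sizes: Sequence[int], bwd_alloc: int, pinned: bool) -> int:
    """Peak live memory (arbitrary units) while running backward over the layers."""
    live = sum(act_sizes)  # all saved forward activations are live when backward starts
    peak = live
    for act in reversed(act_sizes):  # backward visits layers in reverse order
        live += bwd_alloc            # buffers produced by this layer's backward
        peak = max(peak, live)
        if not pinned:
            live -= act              # activation freed; its memory can be reused
    return peak


acts = [4, 4, 4, 4]  # saved activations of four layers
print(backward_peak(acts, bwd_alloc=3, pinned=False))  # 19: same graph, freed as consumed
print(backward_peak(acts, bwd_alloc=3, pinned=True))   # 28: activations are bwd-graph inputs
```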

YangFei1990 commented 10 months ago

@yitongh I see. To make sure I understand correctly: you are saying that for a given microbatch's backward execution, the input activations cannot be gradually released because they are inputs to this backward graph? If so, then yes, at most one microbatch's forward activations will be held until the end of the backward execution. Have you tried activation checkpointing? This issue should be resolved if you checkpoint the activations, since they will be recomputed on the fly as part of the backward pass.

yitongh commented 10 months ago

@YangFei1990 Yes, we have experimented with activation checkpointing, but PP is generally not used in isolation. When we combine PP with TP, we cannot apply gradient checkpointing to the entire PP stage (or decoder layer), as it would increase the communication overhead for TP.

YangFei1990 commented 10 months ago

@yitongh Why is it TP-related? TP is also commonly run with activation checkpointing, and the recomputation is the same as the original forward.

yitongh commented 10 months ago

@YangFei1990 When we use activation checkpointing, we typically don't apply it to the entire decoder layer, but only to the core attention, such as the selective activation checkpointing in Megatron, because larger granularity in activation checkpointing would increase the communication overhead for TP. The aforementioned memory issue still persists when using activation checkpointing solely on the core attention.
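To see why checkpointing only the core attention is attractive, one can use the per-layer activation-memory estimate from Megatron's selective activation recomputation analysis (Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models"): roughly s*b*h*(34 + 5*a*s/h) bytes per transformer layer in 16-bit precision without parallelism, where s is sequence length, b microbatch size, h hidden size, and a the number of attention heads. The 5*a*s/h term is the core-attention part that selective recomputation drops. The helper below just evaluates that approximation for a hypothetical shape and ignores TP/SP sharding.

```python
def layer_activation_bytes(s: int, b: int, h: int, a: int, selective: bool) -> int:
    """Approximate per-layer activation memory in bytes (16-bit, no TP/SP sharding).

    full:      s*b*h*(34 + 5*a*s/h) = 34*s*b*h + 5*a*s*s*b
    selective: 34*s*b*h   (core-attention activations are recomputed instead)
    """
    linear_part = 34 * s * b * h
    attention_part = 5 * a * s * s * b  # attention scores, softmax output, dropout mask/output
    return linear_part if selective else linear_part + attention_part


full = layer_activation_bytes(s=2048, b=1, h=4096, a=32, selective=False)
sel = layer_activation_bytes(s=2048, b=1, h=4096, a=32, selective=True)
print(full, sel)  # 956301312 285212672
```

That is about 0.89 GiB versus 0.27 GiB per layer for this shape. Selective recomputation shrinks the footprint without recomputing the linear layers whose outputs carry TP communication.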

YangFei1990 commented 10 months ago

@yitongh I see. We usually combine the following techniques:

and this works for most of our use cases. Currently, our PP implementation always isolates the P2P comm op in a separate graph, but I will now consider adding an option to remove the mark_step so that graphs across microbatches can be fused. So far, the activation memory pressure mostly comes from pipeline imbalance, and it becomes even worse with the interleaved schedule. Just out of curiosity, could you share your model size and how much extra memory cannot be released due to graph isolation?

yitongh commented 10 months ago

@YangFei1990 The model sizes we use range from 7 billion to 175 billion parameters. In many cases, we cannot use the same 3D parallelism strategy as Megatron because of the extra GPU memory consumed by PP. I haven't measured this extra memory usage in detail, but in certain scenarios it can exceed 10 GB.

YangFei1990 commented 10 months ago

@yitongh That is surprising; maybe I missed something. Since the extra memory is at most one microbatch's activations, if that is 10 GB, the pipeline imbalance would already blow up memory. For example, in the Megatron config, the 175B model runs with PP 16, so the first PP rank accumulates the activation memory of 16 warmup microbatches, which would already take 160 GB.

yitongh commented 10 months ago

@YangFei1990 This issue only affects the bwd of a single microbatch. With PipeDream's 1F1B algorithm, rank 0 consumes the most GPU memory, reaching its peak after N fwd passes. In Megatron, we do not need to account for additional memory overhead during the bwd. In XLA, however, we need to account for the memory usage of N fwd + 1 bwd, where the bwd can add over 10 GB of extra memory. Overall, in Megatron the peak GPU memory usage is reached after N fwd, whereas in XLA it is reached after N fwd + 1 bwd. This forces us to increase the TP size, which, due to TP's communication overhead, degrades performance.
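The two peak-memory accounts compared in this exchange can be written as a back-of-the-envelope formula. It is a simplification: it assumes the 1F1B schedule, so stage `rank` holds at most `min(num_stages - rank, num_microbatches)` microbatches' activations at its peak, a fixed `act_per_mb_gb` per microbatch, and, for the XLA case, an extra `bwd_extra_gb` from one backward graph whose inputs cannot be freed. The numbers in the example are hypothetical and only mirror the 16-stage and 10 GB figures mentioned above.

```python
def peak_stage_memory_gb(
    num_stages: int,
    rank: int,
    num_microbatches: int,
    act_per_mb_gb: float,
    bwd_extra_gb: float = 0.0,
) -> float:
    """Rough peak activation memory of one pipeline stage under 1F1B."""
    # Forwards this stage runs before its first backward frees anything.
    in_flight = min(num_stages - rank, num_microbatches)
    return in_flight * act_per_mb_gb + bwd_extra_gb


# Megatron-style accounting: peak reached after N forwards on rank 0.
print(peak_stage_memory_gb(16, 0, 64, act_per_mb_gb=10.0))                     # 160.0
# XLA accounting described above: N forwards + 1 backward whose inputs stay live.
print(peak_stage_memory_gb(16, 0, 64, act_per_mb_gb=10.0, bwd_extra_gb=10.0))  # 170.0
```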

wconstab commented 10 months ago

Could anyone comment on the blockers/gaps of using the latest PipPy for ptxla?

YangFei1990 commented 10 months ago

@wconstab I have not carefully examined the PiPPy package, but from a high level I think the following is missing:

And, in general, handling anything that expects eager behavior but involves lazy tensors.

wconstab commented 10 months ago

The meta init part we should handle.

I don't know if it's necessary for pippy to know about xla device. What about using meta device to trace pipeline stages and then using ptxla to run them with lazy tensor? You'd retrace each stage graph but not couple the pp tracer logic to the backend.

YangFei1990 commented 10 months ago

@wconstab The solution proposed in this RFC (implemented in NeuronxDistributed) does exactly what you described. If you look into our code, only the components mentioned above are XLA-related, and they are easy to decouple; most components are pure PyTorch. Many UX optimizations we are making, e.g. automatic partitioning, are completely hardware agnostic and can be applied to PiPPy as well. The motivation of this RFC is to have a unique pipeline parallelism solution that works both for eager execution and XLA. That way, both sides can contribute, and it is easy for users to switch between the two. Otherwise, we could also upstream our solution to PT/XLA, but much of the development effort might be duplicated. cc @JackCaoG

wconstab commented 10 months ago

The motivation of this RFC is to have a unique pipeline parallelism solution that works both for eager execution and XLA

hm, could you upstream to PipPy instead? would be nice to have a uniform UX for PP regardless of backend. cc @kwen2501

YangFei1990 commented 10 months ago

Yes @wconstab, that is the proposal of this RFC, assuming Meta will use PiPPy as the official way to do PP in PyTorch.

miladm commented 4 months ago

@YangFei1990 checking in to follow up on this effort. Do we have updates to discuss?