YangFei1990 opened this issue 10 months ago
FYI @wconstab @kwen2501
Currently we use FX's symbolic tracer to do tracing. We plan to move tracing to torch Dynamo for better model coverage and potentially better performance.
PiPPy migrated to PT2 tracer too :) See https://github.com/pytorch/PiPPy/pull/873
Yeah we would like to do the same once the XLA CC ops are traceable for Dynamo.
@alanwaketan I think there is no real blocker for our cc ops to be traceable even with `groups`? I remember we didn't implement it because we didn't need it at that time.
Right, there are no blockers. There was no immediate need on our end back then, and now our distributed story has shifted to focus on SPMD. So feel free to contribute to it, @YangFei1990! If you need anything, I can share the traceable cc design doc with you.
Thanks @YangFei1990 !
Hello, @YangFei1990. We have also implemented pipeline parallelism on GPU based on XLA. In essence, we trace the user's model using torch.fx symbolic trace and split it into PP stages based on the split points provided by the user. Then we run the split fx graph, and finally, like other torch modules, we lower it to XLA through the LTC mechanism. We implemented the PipeDream algorithm for PP based on DeepSpeed's PP implementation. We did not use PiPPy because PiPPy relies on RPC for scheduling, which is heavy. Currently, in our PP implementation, we encounter the following issues:
We need to implement send/recv XLA operations. Currently, support for the send/recv operators at the Python level in PyTorch/XLA is not yet complete. On the GPU side, we need to pass the destination rank for send and the source rank for recv. To use them together with other communication operators, we need to use tokens. However, XLA does not yet support real tokens when implementing operators like all-reduce; it only supports numeric tokens. The tokens for send/recv are not numeric, which means we either need to make operators like all-reduce support real tokens, or make send/recv support numeric tokens.
During the execution of pipeline parallelism, we run many micro-batches, for instance 64. A higher number of micro-batches helps hide the bubble inherent to PP. In large language model scenarios, this results in an excessively large graph if we trace the entire PP execution at once, leading to very long compilation times. However, if we compile after running only a few micro-batches of fwd and bwd, then due to the PipeDream algorithm some activations from the fwd passes will be used as inputs for subsequent bwd passes. Those activations cannot be released during the execution of bwd, which increases GPU memory usage. This is also why we need to implement send/recv in XLA: without it, we would have to compile and execute the fwd and bwd of each micro-batch separately, leading to increased GPU memory usage.
Have you experienced any of the issues mentioned above? Very happy to have the opportunity to discuss them together.
I have opened an issue regarding the token issue of send/recv and other communication operators in XLA. You are welcome to join the discussion.
Hi @yitongh
Question for both @yitongh and @YangFei1990, is each pipeline stage HLO identical in your use case?
@JackCaoG Each pipeline stage has a different graph. Naturally it is MPMD, e.g. the embedding is only on the first PP rank and the lm_head only on the last one.
@JackCaoG Just like @YangFei1990 said, the HLO at each pipeline stage is different.
@YangFei1990 Thanks for your reply.
When using all reduce for communication, how is your communication performance compared to using send/recv?
With send/recv, an async send is an operation that can essentially be ignored by the current stage, and recv is also a cheaper operator than all reduce. However, I believe that in the PP scenario, even if all reduce is used for PP communication, the performance issue is not significant, and the main bottleneck is still the bubble.
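For illustration, here is one way a P2P transfer can be emulated with all-reduce on XLA (a hedged sketch, not necessarily what either implementation does): the sender contributes the real tensor, the receiver contributes zeros, and a SUM all-reduce over just the stage pair delivers the value to the receiver.

```python
# Sketch only: emulate send/recv with xm.all_reduce over a two-rank group.
# Assumes a 2-stage pipeline; with more ranks, `groups` must cover every
# replica exactly once (e.g. disjoint stage pairs exchanging in the same step).
import torch
import torch_xla.core.xla_model as xm

def p2p_via_all_reduce(payload_or_shape, src, dst, is_sender):
    device = xm.xla_device()
    if is_sender:
        buf = payload_or_shape                                   # the activation to send
    else:
        buf = torch.zeros(payload_or_shape, device=device)       # receiver contributes zeros
    # SUM over {src, dst} leaves the sender's tensor on both ranks.
    return xm.all_reduce(xm.REDUCE_SUM, buf, groups=[[src, dst]])
```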
I'm curious about the PP algorithm you're using and how many micro batches you run PP with.
Yes, we use the 1F1B scheduling strategy from PipeDream. Under this strategy we can run a relatively large number of micro batches, so that the bubble issue in PP becomes negligible, for example 128 micro batches with a PP size of 4. In this scenario, we need to consider where to add mark_step. If we add a mark_step only after all micro batches have been executed, we end up with a very large graph. If we add a mark_step after running a few 1F1B cycles, then because the fwd and bwd in 1F1B are not from the same micro batch, some of the temporary activations from fwd will be used by subsequent bwd; these activations become inputs of the bwd graph and cannot be released while bwd executes. Consider a simple scenario where we add a mark_step after each fwd and another after each bwd. In this case, all intermediate activations from fwd that are required for bwd become inputs to the bwd graph, and these inputs cannot be released during the execution of bwd. Since bwd itself consumes GPU memory, this leads to higher than expected memory usage.
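To make the trade-off concrete, here is a minimal sketch of where `mark_step` would sit in a 1F1B loop (the stage module, schedule, and loss below are placeholders, not a real API): cutting after every fwd and bwd bounds graph size and compile time, but it turns the fwd's intermediate activations into inputs of the bwd graph, which is exactly why they cannot be freed while that bwd runs.

```python
# Illustrative 1F1B driver with per-step graph cuts (placeholder names).
import torch
import torch_xla.core.xla_model as xm

def run_1f1b(stage, schedule, microbatches):
    # schedule: e.g. [("fwd", 0), ("fwd", 1), ("bwd", 0), ("fwd", 2), ("bwd", 1), ...]
    live_outputs = {}  # microbatch id -> fwd output kept alive until its bwd
    for kind, mb in schedule:
        if kind == "fwd":
            live_outputs[mb] = stage(microbatches[mb])  # traced lazily by XLA
            xm.mark_step()  # cut: this fwd becomes its own compiled graph
        else:
            loss = live_outputs.pop(mb).sum()  # placeholder loss
            loss.backward()
            xm.mark_step()  # cut: the bwd graph takes the saved fwd activations as inputs
```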
@yitongh
@YangFei1990 The compiler is able to correctly handle dependencies between communication operators due to the presence of tokens. As for the memory issue, if we separate the fwd graph and the bwd graph, we will use more GPU memory than if fwd and bwd were in the same graph. This is because when fwd and bwd are in different graphs, the activations from fwd cannot be released during bwd. These activations are not just the outputs of fwd but also include some intermediate activations required by bwd, such as the outputs of gemm. This issue is related to XLA and not a problem with pipelining.
@YangFei1990 Allow me to provide a more precise explanation of the issue. During normal execution, intermediate activations from fwd are gradually released during the bwd process, and the GPU memory of these activations can be reused by operators in bwd. However, when we separate the fwd and bwd graphs, this reuse is broken because these activations become inputs to the bwd graph. Under the current mechanism of XLA, these inputs cannot be released or reused at runtime.
@yitongh I see. To make sure I understood correctly, you are saying that for a given microbatch's backward execution, the input activations cannot be gradually released because they are used as inputs of that backward graph? If so, there will be at most one microbatch's fwd activations still stored at the end of the backward execution. Have you tried activation checkpointing? That should resolve this issue, since the checkpointed activations are recomputed on the fly together with the backward pass.
@YangFei1990 Yes, we have experimented with activation checkpointing, but PP is generally not used in isolation. When we combine PP with TP, we cannot apply gradient checkpointing to the entire PP stage (or decoder layer), as it would increase the communication overhead for TP.
@yitongh Why is it TP related? I think TP is also commonly run with activation checkpointing; the recomputation is the same as the original forward.
@YangFei1990 When we use activation checkpointing, we typically don't apply it to the entire decoder layer, but only to the core attention, such as the selective activation checkpointing in Megatron, because larger granularity in activation checkpointing would increase the communication overhead for TP. The aforementioned memory issue still persists when using activation checkpointing solely on the core attention.
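For reference, a minimal sketch of selective activation checkpointing on the core attention only (the layer structure and names are illustrative, not the Megatron or NeuronxDistributed code; on XLA, `torch_xla.utils.checkpoint` also provides a checkpoint wrapper):

```python
# Sketch: checkpoint only the core attention so TP collectives in the linear
# layers are not re-run during recomputation, while the large attention
# intermediates are still recomputed instead of being saved for bwd.
import torch
from torch.utils.checkpoint import checkpoint

class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.qkv = torch.nn.Linear(hidden, 3 * hidden)   # TP would shard these
        self.proj = torch.nn.Linear(hidden, hidden)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden, 4 * hidden), torch.nn.GELU(),
            torch.nn.Linear(4 * hidden, hidden))

    def core_attention(self, q, k, v):
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        return torch.softmax(scores, dim=-1) @ v

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Only this call is recomputed in bwd; qkv/proj/mlp run once.
        ctx = checkpoint(self.core_attention, q, k, v, use_reentrant=False)
        x = x + self.proj(ctx)
        return x + self.mlp(x)
```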
@yitongh I see. We usually combine the following techniques
and it works for most of our use cases. Currently our PP implementation always isolates the P2P comm op in a separate graph, but now I would consider adding an option to remove the mark_step so the graphs across microbatches can be fused together. So far the activation memory pressure mostly comes from the imbalance of the pipelining, and it will become even worse for the interleaved scheduler. Just out of curiosity, could you share your model size and how much extra memory cannot be released due to graph isolation?
@YangFei1990 The model sizes we use range from 7 billion to 175 billion parameters. In many cases, we cannot employ the same 3D parallelism strategy as Megatron due to the extra GPU memory consumed by PP. I haven't measured this extra memory usage in detail, but in certain scenarios it can reach 10GB or more.
@yitongh That is surprising. Not sure if I missed anything: since the extra memory is at most one microbatch's activations, if it reaches 10GB then the pipeline imbalance would already blow up memory. For example, in the Megatron config, 175B runs with PP 16, so the first PP rank accumulates 16 warmup microbatches' activation memory, which already takes 160GB.
@YangFei1990 This issue only affects the bwd of a single micro batch. When using PipeDream's 1F1B algorithm, rank 0 consumes the most GPU memory and reaches its peak after N fwd. In Megatron, we do not need to consider additional memory overhead during bwd. However, in XLA, we need to account for the memory usage of N fwd + 1 bwd, where the bwd can add over 10GB of extra memory. Overall, in Megatron the peak GPU memory usage is reached after N fwd, whereas in XLA it is reached after N fwd + 1 bwd. This forces us to increase the TP size, which, due to the communication overhead associated with TP, leads to a decline in performance.
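A back-of-envelope comparison of the two peaks described above, with purely illustrative numbers (not measurements):

```python
# Illustrative numbers only.
warmup_mb     = 16   # warmup microbatches on the first PP rank (PP = 16, 1F1B)
act_per_mb_gb = 10   # fwd activations kept per in-flight microbatch
extra_bwd_gb  = 10   # fwd intermediates pinned as inputs of a separate bwd graph (XLA only)

peak_megatron_gb = warmup_mb * act_per_mb_gb                  # N fwd         -> 160 GB
peak_xla_gb      = warmup_mb * act_per_mb_gb + extra_bwd_gb   # N fwd + 1 bwd -> 170 GB
```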
Could anyone comment on the blockers/gaps of using the latest PipPy for ptxla?
@wconstab I have not carefully examined the Pippy package; from a high level, I think what is missing:
And in general, handle anything that expects eager behavior but involves a lazy tensor.
The meta init part we should handle.
I don't know if it's necessary for pippy to know about xla device. What about using meta device to trace pipeline stages and then using ptxla to run them with lazy tensor? You'd retrace each stage graph but not couple the pp tracer logic to the backend.
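A rough sketch of that flow, assuming the pipeline split itself happens elsewhere (the module shapes and init below are placeholders): build the stage on the meta device, then materialize only the local stage on the XLA device and run it with lazy tensors.

```python
import torch
import torch_xla.core.xla_model as xm

# 1. Construct the stage on the meta device: no real storage, no backend coupling.
with torch.device("meta"):
    stage = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 4096),
    )

# 2. The owning PP rank materializes its stage on the XLA device.
device = xm.xla_device()
stage = stage.to_empty(device=device)        # allocates storage, values are uninitialized
for p in stage.parameters():
    torch.nn.init.normal_(p, std=0.02)       # placeholder; a checkpoint load would go here

# 3. Running the stage is traced lazily and compiled by XLA like any other module.
x = torch.randn(8, 4096, device=device)
y = stage(x)
xm.mark_step()
```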
@wconstab The solution proposed in this RFC (implemented in NeuronxDistributed) does exactly what you described. If you look into our code, only the components mentioned above are XLA related, and they are easy to decouple; most components are pure PyTorch. Many UX optimizations we are making, e.g. automatic partition, are totally hardware agnostic and can be applied to Pippy as well. The motivation of this RFC is to have a unique pipeline parallelism solution that works both for eager execution and XLA, so we can both contribute and it is easy for users to switch between them. Otherwise we could also upstream our solution to PT/XLA, but much development effort might be duplicated. cc @JackCaoG
The motivation of this RFC is to have a unique pipeline parallelism solution that works both for eager execution and XLA
hm, could you upstream to PipPy instead? would be nice to have a uniform UX for PP regardless of backend. cc @kwen2501
Yes @wconstab, that is the proposal of this RFC, assuming Meta will use Pippy as its official way to do PP for PyTorch.
@YangFei1990 checking in to follow up on this effort. Do we have updates to discuss?
🚀 Description
Pipeline parallelism is a technique used in deep learning model training to improve efficiency and reduce the training time of large neural networks. Here we propose a pipeline parallelism solution for PyTorch/XLA with the following goals:
Solution
NeuronxDistributed (a pytorch/xla based library) offers an FX-partition-based pipeline parallelism solution which satisfies both objectives mentioned above. The solution traces the entire model to a static FX graph and partitions the graph into a series of submodules based on user annotations. FX's `GraphModule` is an `nn.Module` generated from an `fx.Graph`. `GraphModule` has a `graph` attribute, as well as `code` and `forward` attributes generated from that `graph`. The output of the partition is a list of graph modules, each belonging to a PP rank. These graph modules can be executed in the same way as native `torch.nn.Module`s and traced by XLA. During runtime, each PP rank will only run the graph module that belongs to it with the predefined PP schedule. Such execution will be traced by XLA, compiled into executable programs, and then executed on the XLA device. In general there is no fundamental difference compared with how XLA devices train models today.

Interacting with other parallelism techniques
Tensor parallelism
Since tensor parallelism operates within a module, and since the current PP solution cuts at the layer boundary while keeping each module untouched, the PP solution can work seamlessly with TP. The same solution is implemented inside the NeuronxDistributed library, which uses Megatron-style TP APIs and FX based PP.
ZeRO/FSDP
We support the ZeRO-1 optimizer, i.e. optimizer state sharding. ZeRO-2/3 and FSDP are not supported yet.
GSPMD
The way GSPMD works is to mark tensors as sharded, and the CC ops are inserted during XLA compiler passes. Since our solution does the PP partition purely at the PyTorch framework layer, the sharded tensors should remain unchanged after the PP partition. The PP send/recv ops live in a different graph from the model execution graph, so GSPMD can insert the corresponding CC ops into the model execution graph. Based on the above analysis, we expect no significant effort to be required to support GSPMD with our solution. For example, we might need to build support for the case where a tensor that is marked as sharded is communicated between PP stages.
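For illustration, a hedged sketch of the composition described above (exact sharding API names may differ across torch_xla versions): sharding annotations are applied to a stage's parameters at the framework level, so they survive the PP partition, and GSPMD inserts the collectives when the stage's graph is compiled.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

xr.use_spmd()                                    # enable SPMD mode (required in newer torch_xla)
device = xm.xla_device()
stage = torch.nn.Linear(1024, 1024).to(device)   # stands in for one PP stage's submodule

n_dev = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n_dev), (n_dev,))

# Shard the stage's weight along dim 0 of the mesh; the PP send/recv of boundary
# activations lives in a separate graph, so it is unaffected by this annotation.
xs.mark_sharding(stage.weight, mesh, (0, None))

x = torch.randn(8, 1024, device=device)
y = stage(x)
xm.mark_step()   # GSPMD inserts the required CC ops into this compiled graph
```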
APIs
Use cases
Supporting pipeline parallelism for a generic model is challenging. It becomes a much simpler problem if we focus on Transformer based models that contain a linear list of modules/ModuleList/Sequential, based on the observation that most PP workloads are Transformer based large language models (LLMs). This assumption not only makes it easy to implement the pipeline partition and runtime, but also reduces certain tracing overhead. While we are focusing on Transformer based models, the design principle itself is generic and could be extended to other model architectures.
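A simplified sketch of a split-point based partition on such a linear-list model (illustrative only; it uses `torch.fx.passes.split_module`, not the actual NeuronxDistributed API):

```python
import torch
import torch.fx
from torch.fx.passes.split_module import split_module

class ToyLLM(torch.nn.Module):
    def __init__(self, n_layers=4, hidden=256):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(hidden, hidden) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:          # linear list of modules, unrolled by FX
            x = torch.relu(layer(x))
        return x

model = ToyLLM()
traced = torch.fx.symbolic_trace(model)

# User annotation: start a new PP stage at layer 2 (layers 0-1 -> stage 0, layers 2-3 -> stage 1).
cut_points = {"layers.2"}

def make_assigner(cuts):
    state = {"stage": 0}
    def assign(node: torch.fx.Node) -> int:
        if node.op == "call_module" and node.target in cuts:
            state["stage"] += 1
        return state["stage"]
    return assign

partitioned = split_module(traced, model, make_assigner(cut_points))
# partitioned.submod_0 / partitioned.submod_1 are plain GraphModules, one per PP rank,
# and can be run (and retraced by XLA) like any other nn.Module.
```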
Current/Future work
Proposal
We would like to understand the PyTorch team's plan and roadmap for PP support and to seek collaboration/integration. Though NeuronxDistributed's PP solution is aimed at running on XLA devices, most of its components are hardware agnostic, and it is easy to decouple the components that require special XLA treatment, mainly the communication part. It would be beneficial to have a PyTorch first-party PP solution that works for both GPU and XLA devices. We have observed that PyTorch has been developing a first-party solution for pipeline parallelism, namely Pippy. Pippy embraces the same idea as our solution, i.e. tracing and doing the PP partition on an FX graph. Previously Pippy's runtime was based on PyTorch RPC, which made it hard to run on XLA devices. However, we recently found that Pippy has been refactored to remove the RPC backend and switch to torch c10d. This refactor makes it much easier to integrate with our solution.
Additional context
cc @JackCaoG @aws-rhsoln @amithrm Please also tag folks that are relevant.