pytorch / xla

[RFC] Support spmd on GPU #6256

Open vanbasten23 opened 6 months ago

vanbasten23 commented 6 months ago

Sharing the current design for SPMD on GPU in PyTorch/XLA. Comments and suggestions are welcome.

🚀 Objective

This design describes what is needed to make GSPMD work on GPU in PyTorch/XLA. With it, we can enable large-scale PyTorch training via GSPMD and leverage the compiler-based sharding framework and tools.

Goals

Non-goals

Design

Usability/User experience

Today's machine learning models increasingly contain billions of parameters and are trained on large amounts of data, so users will likely run SPMD across multiple GPU machines, as is already the case on TPU. We therefore design the user interface around that assumption.

With that in mind, we can make the SPMD user experience similar to our existing multi-host training experience. In multi-host training on GPU, users launch their script with torchrun, for example:

PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" test_pytorch_xla_script.py

Similarly, we propose that users still launch SPMD training with torchrun, for example:

# Each machine runs exactly one process, as required by SPMD.
PJRT_DEVICE=CUDA \
torchrun \
--nnodes=${NUM_GPU_MACHINES} \
--node_rank=${RANK_OF_CURRENT_MACHINE} \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" \
test_pytorch_xla_spmd_script.py

The only notable difference is the value of the --nproc_per_node flag. In general multi-node training, --nproc_per_node specifies how many processes to run on the current node, and it can be as large as the number of GPU devices on that node. For SPMD, however, it must be 1, because SPMD requires one process per node. All other torchrun flags remain the same.
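To make the --nproc_per_node difference concrete, the following minimal sketch assembles the torchrun argument list for both launch modes. The helper build_torchrun_args and its parameters are hypothetical (not part of PyTorch or PyTorch/XLA); the flags are the ones used in the commands above, and PJRT_DEVICE=CUDA still has to be set in the environment of the launched command.

```python
from typing import List


def build_torchrun_args(
    num_machines: int,
    node_rank: int,
    master_ip: str,
    script: str,
    gpus_per_node: int,
    spmd: bool,
    port: int = 12355,
) -> List[str]:
    """Hypothetical helper: assemble a torchrun command line.

    Multi-process training runs one process per GPU, so --nproc_per_node can
    be as large as the number of GPUs on the node. In SPMD mode a single
    process drives every GPU on the node, so --nproc_per_node is always 1.
    """
    if not 0 <= node_rank < num_machines:
        raise ValueError(f"node_rank {node_rank} is out of range for {num_machines} machines")
    nproc_per_node = 1 if spmd else gpus_per_node
    return [
        "torchrun",
        f"--nnodes={num_machines}",
        f"--node_rank={node_rank}",
        f"--nproc_per_node={nproc_per_node}",
        # No shell quoting is needed because this is an argv list.
        f"--rdzv_endpoint={master_ip}:{port}",
        script,
    ]
```

For example, build_torchrun_args(2, 0, "10.0.0.1", "test_pytorch_xla_spmd_script.py", gpus_per_node=4, spmd=True) produces the SPMD command for machine 0 of 2 with --nproc_per_node=1, while spmd=False would use --nproc_per_node=4.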

The script then enables SPMD mode by calling xr.use_spmd() (with torch_xla.runtime imported as xr). In summary, --nproc_per_node=1 together with xr.use_spmd() enables users to run PyTorch/XLA SPMD on GPU.
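A script might also verify that it was launched with one process per host before enabling SPMD. The sketch below assumes only that torchrun sets LOCAL_WORLD_SIZE; the enable_spmd helper is hypothetical, while xr.use_spmd() is the call named in this design.

```python
import os


def enable_spmd() -> None:
    """Hypothetical helper: check the launch configuration, then enable SPMD.

    SPMD needs exactly one process per host, i.e. torchrun was started with
    --nproc_per_node=1, which makes torchrun set LOCAL_WORLD_SIZE=1.
    If the variable is unset (e.g. plain `python script.py`), there is only
    one process, so it defaults to 1.
    """
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    if local_world_size != 1:
        raise RuntimeError(
            f"SPMD requires one process per host, but LOCAL_WORLD_SIZE={local_world_size}; "
            "launch torchrun with --nproc_per_node=1"
        )
    # Imported lazily so that the launch check above runs before torch_xla is loaded.
    import torch_xla.runtime as xr

    xr.use_spmd()
```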

The benefits of this proposal are as follows. It keeps the user experience consistent with PyTorch's multi-node training, which predominantly uses torchrun. It also lets us leverage the GKE tooling we developed for PyTorch/XLA multi-node training and provide a user experience like that of a workload manager such as SLURM, another launch method recommended by PyTorch. Finally, by using xr.use_spmd(), we are consistent with go/pytorch-spmd-usability: we avoid the separate XLA_USE_SPMD environment variable, which simplifies the user experience.

GPU client

Whether the user starts SPMD training directly with torchrun or through GKE, the torchrun command above is run, and torchrun sets the following environment variables on each host:

LOCAL_RANK: the rank of the process within its node. Since SPMD creates one process per host, it is 0 on every GPU machine.
RANK: the global rank of the process. With one process per host, it equals the rank of the current GPU machine and ranges from 0 to the number of GPU machines minus 1.
LOCAL_WORLD_SIZE: the local world size (i.e., the number of processes running on the node) is 1, since SPMD creates one process per host.
WORLD_SIZE: the total number of processes across all hosts, which equals the number of GPU machines.
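The relationship between these variables can be modeled in a few lines. This is a simplified sketch, assuming static rendezvous and the same --nproc_per_node on every node; the helper names are hypothetical. Under those assumptions torchrun numbers processes within a node via LOCAL_RANK and lays out global ranks node by node, so with one process per host LOCAL_RANK is always 0 and RANK equals the machine's rank.

```python
from typing import Dict


def torchrun_env(node_rank: int, num_machines: int, nproc_per_node: int, local_rank: int) -> Dict[str, int]:
    """Simplified model of the rank variables torchrun assigns to one worker.

    Assumes every node runs the same number of workers, so global ranks are
    laid out contiguously node by node.
    """
    return {
        "LOCAL_RANK": local_rank,
        "RANK": node_rank * nproc_per_node + local_rank,
        "LOCAL_WORLD_SIZE": nproc_per_node,
        "WORLD_SIZE": num_machines * nproc_per_node,
    }


def spmd_env(node_rank: int, num_machines: int) -> Dict[str, int]:
    """SPMD case: one process per host, so nproc_per_node=1 and local_rank=0."""
    return torchrun_env(node_rank, num_machines, nproc_per_node=1, local_rank=0)
```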

First, we need to make sure that the single process on each GPU machine can access all of that machine's GPU devices; currently, each process can access only one GPU device. To accomplish this, we need to construct the StreamExecutorGpuClient with the correct GpuClientOptions. Most notably, GpuClientOptions.allowed_devices must be left empty so that the StreamExecutorGpuClient automatically detects all GPU devices attached to the current node.
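The effect of allowed_devices can be illustrated with a small Python model. GpuClientOptionsModel and visible_devices are hypothetical stand-ins written for this sketch, not XLA's C++ API; an unset allowed_devices (None here) plays the role of the empty option described above, and the restricted GPU ordinal is chosen only for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional, Set


@dataclass
class GpuClientOptionsModel:
    """Hypothetical stand-in for the allowed_devices field of GpuClientOptions.

    None means "no restriction": the client uses every GPU it detects.
    """
    allowed_devices: Optional[Set[int]] = None


def visible_devices(options: GpuClientOptionsModel, detected: List[int]) -> List[int]:
    """Return the detected device ordinals the modeled client would use."""
    if options.allowed_devices is None:
        return list(detected)
    return [d for d in detected if d in options.allowed_devices]


detected_gpus = [0, 1, 2, 3]  # e.g. a node with 4 GPUs

# Current multi-process model: each process is restricted to a single GPU.
per_process = visible_devices(GpuClientOptionsModel(allowed_devices={2}), detected_gpus)

# Proposed SPMD model: leave allowed_devices unset so the one process sees every GPU.
spmd_devices = visible_devices(GpuClientOptionsModel(), detected_gpus)
```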

Process group

One concern about using torchrun is how to deal with process groups. A process group is a group of processes that work on one task, and it enables communication among those processes. However, the XLA device process group is not supported under SPMD: SPMD runs as a single replica, so the compiler should handle communication and coordination instead of users manually coordinating processes via collective ops. The problem is that if an XLA process group is created (by torchrun or something else) and we call dist.init_process_group, the code would crash in SPMD mode.

It turns out this is not an issue. First, the XLA process group is only available once the torch_xla.distributed.xla_backend module is imported, and an SPMD script does not import that module. Second, because SPMD does not need a process group, there is no reason to call dist.init_process_group in an SPMD training or inference script.
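One way to encode this reasoning is a sanity check at the top of an SPMD script. The check below is a hypothetical helper, not a PyTorch/XLA API; it only inspects sys.modules to confirm that torch_xla.distributed.xla_backend has not been imported.

```python
import sys

# Module whose import makes the XLA process-group backend available.
XLA_BACKEND_MODULE = "torch_xla.distributed.xla_backend"


def check_no_xla_process_group() -> None:
    """Hypothetical sanity check for an SPMD script.

    Under SPMD the compiler inserts the collectives, so the script should
    neither import the XLA process-group backend nor call
    dist.init_process_group.
    """
    if XLA_BACKEND_MODULE in sys.modules:
        raise RuntimeError(
            f"{XLA_BACKEND_MODULE} is imported; XLA process groups are not supported under SPMD"
        )
```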

PyTorch/XLA technologies

The design should work well with dynamo and non-dynamo execution, inference, and FSDP v2, because it does not change the SPMD interface in any way. In other words, those technologies should work in a hardware-agnostic way, so they will be discussed in detail in separate documents.

Future work

Eventually, for performance's sake, we want to adopt the one-GPU-per-process model, even for SPMD. The one-GPU-per-process model is the optimal configuration because each process can then be bound to a NUMA domain and a single NIC. Enabling it may require an overhaul of the existing SPMD design, so this work will come after the current effort.

Performance

We will use Llama 2 for benchmarking, since it is the model we benchmark with PyTorch/XLA SPMD on TPU. As a first step, we will compare performance against JAX SPMD on GPU, because the only difference is the ML framework (PyTorch/XLA vs. JAX); everything else (the XLA GPU compiler and the hardware) is the same. Later, we will compare performance with PyTorch on Inductor, PyTorch eager, PyTorch/XLA SPMD on TPU, and PyTorch/XLA FSDP.

Additional context

miladm commented 6 months ago

Thanks @vanbasten23! Can you share a pointer/discussion re: our plans to enable auto-sharding as a future direction?

vanbasten23 commented 6 months ago

To my understanding, the auto-sharding feature should be device-agnostic, and our design (https://github.com/pytorch/xla/issues/6322) should work for both TPU and GPU. Enabling auto-sharding will come after we enable SPMD on GPU. We may need to make some changes in XLA:GPU, similar to what @yeounoh is doing in XLA:TPU.

Priority-wise, I imagine we will want to see how auto-sharding works on TPU and how the new SPMD on GPU performs on an LLM benchmark model, and then decide which advanced feature to pursue next.

cc: @yeounoh