pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[RFC] PyTorch/XLA Auto-Sharding API #6322

Open yeounoh opened 7 months ago

yeounoh commented 7 months ago

🚀 Feature & Motivation

PyTorch/XLA recently launched PyTorch/XLA SPMD (RFC, blog, docs/spmd.md) as a first step toward automating the parallelization of ML workloads using GSPMD. Previously, PyTorch users had to implement tensor and model parallelism for their ML workloads themselves. PyTorch/XLA SPMD instead lets users provide a handful of "sharding hints" via the PyTorch/XLA SPMD sharding annotation API and keep the original model implementation as-is. We presented some exciting Cloud TPU PyTorch LLaMA2 training results using SPMD at Google Next (blog). Another highlight of PyTorch/XLA SPMD is that it enables more advanced and hybrid parallelism strategies for PyTorch users, combining data and tensor parallelism as well as pipelining. While this is all great and we are happy to release the new feature to the PyTorch community, one challenge remains: providing optimal sharding hints. Performance largely depends on the quality of the sharding hints provided by the user, and coming up with optimal hints requires a correct and deep understanding of the model architecture as well as considerable expertise.
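
To make "sharding hints" concrete: a hint maps each tensor dimension onto an axis of a logical device mesh (or leaves it replicated), and each device then holds one shard. The helper below is hypothetical and only computes the resulting per-device shard shape for a given mesh shape and partition spec. Its convention of one entry per tensor dimension (a mesh-axis index or None) mirrors the style of the mark_sharding partition spec. It assumes sharded dimensions divide evenly.

```python
def shard_shape(tensor_shape, mesh_shape, partition_spec):
    """Per-device shard shape when tensor dim i is split over mesh axis partition_spec[i].

    A None entry in partition_spec leaves that tensor dimension replicated.
    Assumes every sharded dimension divides evenly by its mesh axis size.
    """
    if len(tensor_shape) != len(partition_spec):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    shape = []
    for dim, axis in zip(tensor_shape, partition_spec):
        if axis is None:
            shape.append(dim)  # replicated dimension: every device holds it whole
        elif dim % mesh_shape[axis] != 0:
            raise ValueError(f"dimension {dim} is not divisible by mesh axis size {mesh_shape[axis]}")
        else:
            shape.append(dim // mesh_shape[axis])  # split across the mesh axis
    return tuple(shape)


mesh_shape = (2, 4)  # 8 devices arranged as a 2 x 4 logical mesh
print(shard_shape((4096, 8192), mesh_shape, (0, 1)))     # (2048, 2048): 2D sharding
print(shard_shape((4096, 8192), mesh_shape, (None, 1)))  # (4096, 2048): 1D sharding
```

The hard part is not this arithmetic. It is choosing a good partition spec for every tensor in the model, which is exactly the choice auto-sharding aims to take off the user's hands.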

To address this problem, we propose integrating PyTorch/XLA SPMD with XLA's auto-sharding service, which allows the XLA compiler to shard and optimize the whole model without any user input. The XLA auto-sharding service is based on published research, Alpa (blog). While this may sound like a leap of faith, it is already being tested and is showing promising results on Google-internal workloads (see also our mini-benchmark results below).
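
Conceptually, an auto-sharding pass picks one sharding strategy per operator so that the total compute cost plus resharding (communication) cost is minimized. Alpa's intra-operator pass solves this as an integer linear program over the whole graph, with costs estimated for the target cluster. The toy sketch below is only an illustration of the search, not XLA's implementation. It brute-forces a two-layer chain using made-up costs and hypothetical strategy names: "R" for replicated, and "S0"/"S1" for sharded along dimension 0/1.

```python
from itertools import product


def choose_shardings(node_costs, edge_cost):
    """Pick one strategy per layer minimizing node costs plus resharding costs.

    node_costs: list of dicts, one per layer, mapping strategy name -> cost.
    edge_cost: function (producer_strategy, consumer_strategy) -> resharding cost.
    Returns (total_cost, tuple_of_strategies). Exhaustive, so only for tiny examples.
    """
    best = None
    for combo in product(*(list(costs) for costs in node_costs)):
        total = sum(node_costs[i][s] for i, s in enumerate(combo))
        total += sum(edge_cost(a, b) for a, b in zip(combo, combo[1:]))
        if best is None or total < best[0]:
            best = (total, combo)
    return best


# Made-up per-layer costs for a two-layer toy model.
layers = [
    {"R": 4.0, "S0": 1.0, "S1": 1.5},
    {"R": 4.0, "S0": 2.5, "S1": 1.0},
]


def reshard_cost(a, b):
    # Changing layout between consecutive layers costs one (made-up) unit.
    return 0.0 if a == b else 1.0


print(choose_shardings(layers, reshard_cost))  # (2.5, ('S1', 'S1'))
```

Picking the cheapest strategy for each layer in isolation gives S0 then S1, for a total of 3.0 once the reshard is counted. The joint search instead finds S1, S1 at 2.5, which is why the decision is made over the whole model rather than layer by layer.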

API and Usage Example

To enable auto-sharding, simply call use_spmd with the auto=True flag:

import torch_xla.runtime as xr 

# Enable XLA SPMD execution mode with auto-sharding.
xr.use_spmd(auto=True)

# Write a PyTorch program without any sharding annotations.
# Sharding hints are still allowed, but optional.
...

That should be it: the PyTorch program is then automatically sharded and executed.

There are also optional configuration knobs:

# The optional auto_sharding config controls the auto-sharding behavior.
config = {"auto_sharding": {"partitioner": "alpa", "keep_user_sharding": True}}
xr.use_spmd(auto=True, spmd_config=config)

Auto-sharding uses the "alpa" auto-partitioner (partitioner="alpa"), which is currently the only option available.

Auto-sharding runs in SPMD mode and needs zero sharding hints to work; the execution is sharded and optimized by the XLA compiler. By default, the auto-sharding pass respects pre-existing sharding annotations on the inputs and outputs. Users can provide more hints with the PyTorch/XLA SPMD mark_sharding API by also setting the keep_user_sharding option.

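Auto-sharding configuration

The auto-sharding configuration options used in the example above are:

- partitioner: the auto-partitioning algorithm; "alpa" is currently the only option.
- keep_user_sharding: when True, the auto-sharding pass keeps the sharding hints the user provided via the mark_sharding API.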

PyTorch DTensor integration

It is important to plug into the PyTorch distributed API to provide a unified UX for PyTorch distributed [RFC]. DTensor is PyTorch's SPMD-style distributed tensor computation API, with which PyTorch/XLA SPMD is integrated. We propose introducing a new auto partition function to the DTensor distribute_module API via partition_fn="AUTO":

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module

# Define a PyTorch module
...
my_module = MyModule().to(xm.xla_device())

# Automatically sharded (annotated) module for XLA distributed execution
world_size = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(world_size)))
# partition_fn="AUTO" is the new option proposed in this RFC.
my_sharded_module = distribute_module(my_module, mesh, partition_fn="AUTO")

Mini-Benchmark Results

Here we present preliminary benchmark results using GPT-2, LLaMA2 and GPT-Neo from HuggingFace. Auto-sharding parallelizes and distributes these transformer-based language models without any user sharding annotations on the models. We used PyTorch/XLA's MpDeviceDataLoader for background data loading with batch-dimension sharding.
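
Batch-dimension sharding means each device receives a contiguous slice of the global batch. The hypothetical helper below computes those slices. It is a sketch, not MpDeviceDataLoader's implementation, and it assumes for simplicity that the global batch size divides evenly across devices.

```python
def batch_shard_slices(global_batch_size, num_devices):
    """Return the [start, stop) row range of the global batch held by each device."""
    if global_batch_size % num_devices != 0:
        raise ValueError("global batch size must be divisible by the number of devices")
    per_device = global_batch_size // num_devices
    return [(d * per_device, (d + 1) * per_device) for d in range(num_devices)]


print(batch_shard_slices(128, 4))  # [(0, 32), (32, 64), (64, 96), (96, 128)]
```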

![gpt2_v4_8_mfu_batch](https://github.com/pytorch/xla/assets/7146489/8248665d-89e0-45a4-916b-a854a457597f) ![gpt2_2b_step_time_vs_batch](https://github.com/pytorch/xla/assets/7146489/e2f9afc8-a639-4308-8205-ae59f658912a)

The GPT-2 figures above show preliminary benchmark results for a GPT-2 (2B parameters) model on TPUv4-8: the auto-sharding pass produces results comparable to human-curated 2D sharding strategies.

![llama2_2b_bsz128](https://github.com/pytorch/xla/assets/7146489/ab02a9f2-4876-4c53-a230-f81425180265) ![perf_auto_vs_manual](https://github.com/pytorch/xla/assets/7146489/5490643a-e044-41a3-a4e0-ea8858a11d22)

The figures above show that (left) auto-sharding does not always find the best-performing (highest-MFU) sharding, as in the case of LLaMA2, though it still produces performant ones. Importantly, (right) it worked with three popular HuggingFace models without any customization or manual annotations.
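
MFU (model FLOPs utilization, as defined in the PaLM paper) compares the model FLOPs achieved per second with the hardware's peak FLOPs per second. A common simplification for transformer training counts about 6 FLOPs per parameter per token (forward plus backward) and ignores the attention term. The sketch below applies that approximation. The RFC does not say exactly how its MFU numbers were computed, and the inputs in the usage line are hypothetical placeholders, not the measurements behind the figures above.

```python
def mfu(num_params, tokens_per_step, step_time_s, num_chips, peak_flops_per_chip):
    """Approximate MFU with the ~6 FLOPs per parameter per token training rule.

    Ignores attention FLOPs, so it somewhat underestimates MFU, more so for long sequences.
    """
    achieved_flops_per_s = 6 * num_params * tokens_per_step / step_time_s
    peak_flops_per_s = num_chips * peak_flops_per_chip
    return achieved_flops_per_s / peak_flops_per_s


# Hypothetical inputs: a 2B-parameter model, 128 sequences x 1024 tokens per step,
# a 3.0 s step time, and 4 chips at a nominal 275e12 peak FLOP/s each.
print(round(mfu(2e9, 128 * 1024, 3.0, 4, 275e12), 3))  # 0.477
```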

Alternatives

This work automates model parallelization with PyTorch/XLA SPMD, letting the XLA compiler come up with optimal sharding strategies on behalf of the user. As an alternative, we will also introduce a high-level API (e.g., FSDP) that iteratively calls PyTorch/XLA SPMD for a given policy (RFC). Our goal is to give PyTorch users useful tools and a good range of options.

Additional Context

Alpa-based auto-sharding is still an experimental feature. It targets XLA-supported hardware types such as TPU and GPU, and we hope to provide a single approach for all PyTorch/XLA backend types. In the near future, we will also expand the choice of auto-sharding algorithms beyond Alpa.

cc @JackCaoG @miladm @shauheen @alanwaketan @baoleai @anw90 @yitongh @Seventeen17 @wconstab @wanchaol

yeounoh commented 7 months ago

This was meant to be a draft; I am still actively editing lol. EDIT: OK, it's ready for review & further discussion.

wconstab commented 7 months ago

I would like to see more detail on how users specify manual placement for xla via dtensor.

Then, I wonder why adding an auto-placement to dtensor is required. Why can't auto be assumed if no manual sharding is specified?

yeounoh commented 7 months ago

> I would like to see more detail on how users specify manual placement for xla via dtensor.

Yes, will follow up with a tutorial with @wanchaol

> Then, I wonder why adding an auto-placement to dtensor is required. Why can't auto be assumed if no manual sharding is specified?

I would say that's mainly because in non-auto SPMD, users are not required to provide hints for all tensors. It really comes down to using the user-provided hints plus sharding propagation versus allowing XLA to search for better shardings. Let's talk more, thanks @wconstab!
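
A toy illustration of that distinction: sharding propagation can only spread the hints that exist, whereas auto-sharding searches for shardings even where none were given. The hypothetical helper below fills unannotated tensors in a linear chain from their nearest annotated neighbor. Real GSPMD propagation is op-aware and works over the whole graph, so the helper and the "batch"/"model" labels are purely illustrative.

```python
def propagate(chain_hints):
    """Fill unannotated (None) tensors in a linear chain from their annotated neighbors."""
    result = list(chain_hints)
    # Forward pass: copy the sharding of the nearest annotated predecessor.
    for i in range(1, len(result)):
        if result[i] is None:
            result[i] = result[i - 1]
    # Backward pass: fill any leading gap from the nearest annotated successor.
    for i in range(len(result) - 2, -1, -1):
        if result[i] is None:
            result[i] = result[i + 1]
    return result


print(propagate([None, "batch", None, None, "model", None]))
# ['batch', 'batch', 'batch', 'batch', 'model', 'model']
print(propagate([None, None, None]))
# [None, None, None]: with no hints there is nothing to propagate
```

The second call is the case auto-sharding addresses: without any hints, propagation has nothing to start from, and the shardings must be searched for instead.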

miladm commented 7 months ago

Awesome! @yeounoh can you please add a few words on XLA:GPU vs. XLA:TPU auto-sharding plan/capabilities and plans to have a singular approach for all HW backends?

yeounoh commented 7 months ago

> Awesome! @yeounoh can you please add a few words on XLA:GPU vs. XLA:TPU auto-sharding plan/capabilities and plans to have a singular approach for all HW backends?

+1 added some additional context around it.

alanwaketan commented 7 months ago

This looks great. Wondering if you have tested it against non-LLMs? And also wondering if you have test results beyond a single TPU host.

hyviquel commented 4 months ago

@yeounoh do you plan to add support for auto-sharding to GPUs as well?

yeounoh commented 4 months ago

> @yeounoh do you plan to add support for auto-sharding to GPUs as well?

Hi @hyviquel -- yes, it's in our roadmap. It's currently blocked on the XLA:GPU side, and we will work to enable it.

hyviquel commented 4 months ago

> Hi @hyviquel -- yes, it's in our roadmap. It's currently blocked on the XLA:GPU side, and we will work to enable it.

Nice, thanks for the quick answer.

Will the forthcoming improvements directly enable auto-sharding across multiple GPU nodes, similar to the ALPA and Ray combination, or is further development required to match that level of functionality?

yeounoh commented 4 months ago

> Will the forthcoming improvements directly enable auto-sharding across multiple GPU nodes, similar to the ALPA and Ray combination, or is further development required to match that level of functionality?

Once we are unblocked on GPU support, it should work across multiple GPU nodes. The auto-sharding feature is based on SPMD, which we are also working to enable for XLA:GPU (https://github.com/pytorch/xla/issues/6256). That said, I expect further development and optimization work will be required for fully functional XLA:GPU support.

hyviquel commented 2 weeks ago

@vanbasten23 I noticed you just closed #6256; does that mean auto-sharding should work for XLA:GPU now?

i-Pear commented 1 week ago

> @vanbasten23 I noticed you just closed #6256; does that mean auto-sharding should work for XLA:GPU now?

@hyviquel I’ve recently been exploring auto sharding on GPUs and have gained some experience that I’d like to share.

  1. This feature relies on the Alpa algorithm in OpenXLA, which isn't enabled by default due to issues with TensorFlow's linker [1]. However, with a few manual patches from [1], it can be enabled in torch-xla, and the algorithm works correctly.

  2. There are some compatibility issues between OpenXLA’s GPU compiler and the Alpa algorithm interface, causing the sharding algorithm to retrieve incorrect device information. I’m going to submit a patch to OpenXLA to fix this.

  3. There are some compilation errors in the Alpa algorithm implementation, likely due to recent dependency changes. Fortunately, these are not difficult to fix.

Once the above issues are addressed, auto sharding should work on GPUs.

[1] https://github.com/openxla/xla/pull/13952

hyviquel commented 1 week ago

Nice, thanks @i-Pear, please keep us up to date on the progress! I would be happy to help test it once it's ready.