pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[RFC] A high-level GSPMD API in PT/XLA (based on `xs.mark_sharding`) #3755

Open ronghanghu opened 2 years ago

ronghanghu commented 2 years ago

🚀 [RFC] A high-level GSPMD API in PT/XLA (based on xs.mark_sharding)

This RFC proposes a high-level API for GSPMD through a wrapper class and a partitioning rule function, based on xs.mark_sharding.

Motivation

GSPMD is a powerful approach for model sharding and parallelization, and is enabled in PyTorch/XLA by the xs.mark_sharding API in https://github.com/pytorch/xla/pull/3476 and https://github.com/pytorch/xla/pull/3684. However, there are a few limitations when directly using this API to build SPMD programs:

Ideally, one would like to be able to take any existing PyTorch model (e.g. BERT from Hugging Face) and apply a specific sharding strategy to it (e.g. Megatron) without changing the model code (which, in this case, lives in the Hugging Face library).

In this RFC, we advocate for a high-level SPMD API in PyTorch/XLA (built upon xs.mark_sharding) that does not require users to rewrite their entire model code to be aware of the SPMD sharding partitions. Specifically, we propose to parallelize a model with SPMD by first building the base model and then wrapping it with a wrapper class GSPMDParallel, similar to how DDP or FSDP is applied to an existing model.

The proposed GSPMDParallel class takes as input

  1. a given PyTorch module (the base model, an nn.Module instance), and
  2. a user-specified function/callable sharding_rule_func to define the sharding rules, and
  3. a mesh_shape tuple to define the TPU mesh.

Since those tensors we want to shard in SPMD are usually either the parameters or the input/output tensors of a submodule (e.g. an nn.Linear) in the whole network, the GSPMDParallel class will recursively go down to all the submodules in the wrapped model and apply the sharding rules to the parameters, inputs, and outputs of each submodule. The sharded HLO graph will be built when forward (and backward) is called on the GSPMDParallel class.

In this way, the base model code doesn't need to be changed when using SPMD -- it will just be wrapped and partitioned in a post-hoc manner, so all the current model implementations from the abundant existing PyTorch libraries (e.g. timm, torchvision, Hugging Face) can be directly used for SPMD parallelism without rewriting these libraries.

Proposed Implementation

Building upon the xs.mark_sharding API, the proposed implementation (prototype) of the high-level SPMD API consists of a GSPMDParallel class that recursively applies a user-defined function/callable sharding_rule_func to the submodules of an input module, which can be based on their names and the submodule instances themselves.

A prototype is as follows:

from copy import deepcopy
from typing import Callable, Tuple, Union

import torch.nn as nn


class GSPMDParallel(nn.Module):
  """Recursively apply a `sharding_rule_func` to submodules of an input `module`."""

  def __init__(self, module: nn.Module, sharding_rule_func: Callable, mesh_shape: Tuple[Union[int, None], ...]):
      super().__init__()

      # apply SPMD sharding rule to the base model
      sharded_module = self.apply_sharding(module, sharding_rule_func, mesh_shape)
      self.module = sharded_module
      self.sharding_rule_func = sharding_rule_func
      self.mesh_shape = mesh_shape

  def apply_sharding(self, module, sharding_rule_func, mesh_shape):
      sharded_module = deepcopy(module)  # maybe make a copy if we want to keep the original `module`

      # recursively apply the sharding rule to all the submodules of the copy
      # (iterating over the original `module` would leave the copy unsharded)
      for name, m in sharded_module.named_modules():
          sharding_rule_func(name, m, mesh_shape)

      return sharded_module

  def forward(self, *args, **kwargs):
      return self.module(*args, **kwargs)

And a user-defined sharding rule function/callable will decide whether and how to apply SPMD sharding annotations to an nn.Module's parameters, inputs, and outputs with xs.mark_sharding. For example, the following sharding rule function can be used to apply the Megatron sharding to those MLP layers in a timm Vision Transformer (ViT).

from types import MethodType

# module path at the time of this RFC; it may differ in later releases
import torch_xla.experimental.xla_sharding as xs


def example_sharding_rule_func(name: str, submodule: nn.Module, mesh_shape: Tuple[Union[int, None], ...]):
    """Apply Megatron to MLP layers in timm ViT. Assuming (data, model) `mesh_shape` like T5X."""

    # timm names these submodules e.g. `blocks.0.mlp.fc1`, so match on the `mlp.fc1` suffix
    if name.endswith("mlp.fc1"):
      assert isinstance(submodule, nn.Linear)
      # shard the 1st MLP layer's weight param (mlp_dim, hidden_size)
      submodule.weight = xs.mark_sharding(submodule.weight, mesh_shape, (1, None))
      # shard the 1st MLP layer's bias param (mlp_dim,)
      submodule.bias = xs.mark_sharding(submodule.bias, mesh_shape, (1,))

      # shard the 1st MLP layer's output (batch_size, seq_length, mlp_dim) by patching its forward
      # TODO (change to decorators on `forward` to get cleaner code)
      submodule._orig_forward = submodule.forward
      def _new_forward(m, x):
          return xs.mark_sharding(m._orig_forward(x), mesh_shape, (0, None, 1))
      submodule.forward = MethodType(_new_forward, submodule)

    elif name.endswith("mlp.fc2"):
      assert isinstance(submodule, nn.Linear)

      # shard the weight (hidden_size, mlp_dim) in the 2nd MLP layer
      submodule.weight = xs.mark_sharding(submodule.weight, mesh_shape, (None, 1))

In the example above, it is straightforward to shard a submodule's parameter tensors, but rather hacky to shard its input and output tensors (by manually patching the forward method). We can switch to decorators to get cleaner code.
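
As a rough illustration of the decorator idea, the sketch below wraps a forward method so that its return value is passed through an annotation function. The names `shard_output` and `mark_fn` are hypothetical and not part of PT/XLA; `mark_fn` is a stand-in for `xs.mark_sharding`, and a tiny placeholder class replaces `nn.Linear` so that the snippet runs without an XLA device.

```python
import functools


def shard_output(mark_fn, mesh_shape, partition_spec):
    """Decorator factory: annotate the return value of a forward method.

    `mark_fn` is a stand-in for `xs.mark_sharding` (hypothetical wiring).
    """
    def decorator(forward):
        @functools.wraps(forward)
        def wrapper(*args, **kwargs):
            out = forward(*args, **kwargs)
            return mark_fn(out, mesh_shape, partition_spec)
        return wrapper
    return decorator


class _TinyLinear:
    """Placeholder for nn.Linear, used only to exercise the decorator."""
    def forward(self, x):
        return [2 * v for v in x]


calls = []

def fake_mark(tensor, mesh_shape, partition_spec):
    # Record the annotation instead of talking to an XLA backend.
    calls.append((mesh_shape, partition_spec))
    return tensor


layer = _TinyLinear()
# Decorate the bound method in place; the layer's class is left untouched.
layer.forward = shard_output(fake_mark, (2, 4), (0, None, 1))(layer.forward)
result = layer.forward([1, 2, 3])
```

Because the decorator is applied to the already-bound `forward`, no `MethodType` bookkeeping or `_orig_forward` attribute is needed.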

We could also provide a few easier ways to build sharding_rule_func or to use it in GSPMDParallel. For example, we can make sharding_rule_func a callable class instead of a function, and provide a good class structure to build the sharding rule callables.
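
One possible shape for such a callable class is sketched below. It is only an illustration: `PatternShardingRule`, `mark_fn`, and `annotated` are hypothetical names, `mark_fn` stands in for `xs.mark_sharding`, and `SimpleNamespace` objects stand in for real submodules so that the snippet runs without torch or an XLA device.

```python
import re
from types import SimpleNamespace


class PatternShardingRule:
    """Callable rule: regex on the submodule name -> {param name: partition spec}."""

    def __init__(self, rules, mark_fn):
        # rules: list of (regex pattern, {param_name: partition_spec})
        self.rules = [(re.compile(pattern), specs) for pattern, specs in rules]
        self.mark_fn = mark_fn   # stand-in for xs.mark_sharding
        self.annotated = []      # log of (qualified param name, partition spec)

    def __call__(self, name, submodule, mesh_shape):
        for pattern, specs in self.rules:
            if not pattern.search(name):
                continue
            for param_name, spec in specs.items():
                tensor = getattr(submodule, param_name, None)
                if tensor is None:
                    continue
                self.mark_fn(tensor, mesh_shape, spec)
                self.annotated.append((f"{name}.{param_name}", spec))
            break  # the first matching pattern wins


marked = []
rule = PatternShardingRule(
    rules=[
        (r"mlp\.fc1$", {"weight": (1, None), "bias": (1,)}),
        (r"mlp\.fc2$", {"weight": (None, 1)}),
    ],
    mark_fn=lambda tensor, mesh_shape, spec: marked.append((tensor, mesh_shape, spec)),
)

fake_modules = [
    ("blocks.0.mlp.fc1", SimpleNamespace(weight="W1", bias="b1")),
    ("blocks.0.mlp.fc2", SimpleNamespace(weight="W2", bias="b2")),
    ("blocks.0.attn.qkv", SimpleNamespace(weight="Wqkv")),
]
for name, module in fake_modules:
    rule(name, module, (2, 4))
```

Since the rule object has the same `(name, submodule, mesh_shape)` call signature as `sharding_rule_func`, it could be passed to `GSPMDParallel` unchanged, and its `annotated` list gives a record of what was sharded.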

Alternatives

The implementation above requires a sharding_rule_func (that takes a submodule object and its name) to decide how to decorate its parameters, inputs, or return values using the xs.mark_sharding API. While this should be sufficient to implement nearly all SPMD use cases, it is hard to later inspect the sharding annotations in a GSPMDParallel instance, such as printing a list of sharded tensors annotated by xs.mark_sharding and their partitioning details. It relies on the user to keep track of what is sharded by sharding_rule_func.

An alternative way to implement this GSPMDParallel class is to enforce a more principled way to define the sharding rule. Rather than having an arbitrary function to do anything, one can have a name-based (string-based) sharding rule similar to the logical axis names in T5X. Under this name-based sharding rule definition, a sharding rule consists of the following:

1) a function/callable to extract those tensors to be sharded (param, input tensor, return values) from a submodule and map their tensor axes to logical axis names (such as (batch, mlp_size)).

2) a user-specified mapping rule (e.g. a list) to map a logical axis to a TPU mesh axis, similar to those in T5X.
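
The second step can be sketched in a few lines of plain Python. This is loosely modeled on the T5X idea rather than on its actual API: the function name `logical_to_partition_spec`, the dict-based `axis_rules`, and the logical axis names are all assumptions made for illustration.

```python
def logical_to_partition_spec(logical_axes, axis_rules, mesh_axis_names):
    """Map a tensor's logical axis names to mesh axis indices.

    Returns one entry per tensor dimension: the index of the mesh axis to
    shard along, or None if that dimension is replicated.
    """
    spec = []
    used = set()
    for axis in logical_axes:
        mesh_axis = axis_rules.get(axis)
        if mesh_axis is None:
            spec.append(None)  # no rule (or an explicit None): replicate
            continue
        if mesh_axis in used:
            raise ValueError(f"mesh axis {mesh_axis!r} used twice for {logical_axes}")
        used.add(mesh_axis)
        spec.append(mesh_axis_names.index(mesh_axis))
    return tuple(spec)


# Assumed (data, model) mesh, as in the Megatron example earlier in this RFC.
mesh_axis_names = ("data", "model")
axis_rules = {"batch": "data", "mlp": "model", "embed": None, "length": None}

fc1_weight_spec = logical_to_partition_spec(("mlp", "embed"), axis_rules, mesh_axis_names)
fc1_output_spec = logical_to_partition_spec(("batch", "length", "mlp"), axis_rules, mesh_axis_names)
fc2_weight_spec = logical_to_partition_spec(("embed", "mlp"), axis_rules, mesh_axis_names)
```

With these rules, the three specs reproduce the tuples written by hand in `example_sharding_rule_func`, while the rule table itself stays easy to print and inspect.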

An orthogonal and complementary way to simplify the sharding rule is to use Named Tensors API to give each tensor axis a name (a string). If the users need to implement both their base model (to be wrapped by GSPMDParallel) as well as their sharding_rule_func, then it would be easier for them to use named tensors in their base model implementation and refer to those axis names in their sharding_rule_func implementation. However, this approach cannot be applied to existing models (e.g. those in torchvision or Hugging Face) that don't use the named tensors, so named tensors should not be a requirement in the GSPMDParallel class.

Additional context

A related problem is how to save and load an SPMD partitioned model's parameters and optimizer state dicts, especially in those cases where the full model cannot fit into a single TPU VM's host memory. This part requires a mechanism to save and load checkpoints in a distributed manner without consolidating them onto a single host.


On our end (FAIR), we are happy to work on a prototype implementation of the GSPMDParallel class above and first try it out in our internal use cases. We can submit a PR once we have a mature implementation.

cc: @yeounoh @JackCaoG @miladm @ultrons

yeounoh commented 2 years ago

Thank you @ronghanghu!

yeounoh commented 2 years ago

Hi @ronghanghu, this may or may not be relevant. The above examples seem correct, but for what it's worth, mark_sharding returns an XLAShardedTensor, which is just a wrapper around the original tensor -- mark_sharding still passes the tensor to the backend with the sharding spec. This means that we don't need to replace the module layer/params with the output of mark_sharding. Here is an example:

device = xm.xla_device()
# load the predefined model
model = get_model_property('model_fn')().to(device)
# mark sharding
for name, layer in model.named_modules():
  if 'conv' in name:
    partition = xm.xrt_world_size() // 2
    xs.mark_sharding(layer.weight, (1, 1, partition, partition), (0, 1, 2, 3))

ronghanghu commented 2 years ago

mark_sharding returns an XLAShardedTensor, which is just a wrapper around the original tensor -- mark_sharding still passes the tensor to the backend with the sharding spec. This means that we don't need to replace the module layer/params with the output of mark_sharding.

@yeounoh This is great to know and makes a lot of things easier :) Thanks for the update!

In this scenario, what is the difference between using the output of xs.mark_sharding and directly using the original tensor? My understanding is that xs.mark_sharding returns an XLAShardedTensor that has a special __torch_function__ to handle subsequent tensor ops. In the snippet above, if we don't assign the output of xs.mark_sharding back to layer.weight, would subsequent computation using layer.weight follow our sharding annotations?

If we can do

y = x * 2
xs.mark_sharding(y, (1, 1, partition, partition), (0, 1, 2, 3))
# Case 1: using the existing tensor `y` instead of the output of `xs.mark_sharding`
z = y + 1

or even

y = x * 2
z = y + 1
# Case 2: like Case 1 but running `mark_sharding` at the end
xs.mark_sharding(y, (1, 1, partition, partition), (0, 1, 2, 3))

instead of

y = x * 2
y_shard = xs.mark_sharding(y, (1, 1, partition, partition), (0, 1, 2, 3))
# Case 3: using the output of `xs.mark_sharding`
z = y_shard + 1

then it makes life much easier, since we can still have the same Python object id, etc. This allows us to e.g. first construct the entire model and the optimizer, and then just mark a few of its tensors.

I wonder how Case 1 would work exactly. Is y still a regular torch.Tensor or an XLAShardedTensor? And should xs.mark_sharding return None if this is just an in-place annotation?

yeounoh commented 2 years ago

All three scenarios are possible, since the annotation is first attached to the original tensor and XLAShardedTensor wraps around it. XLAShardedTensor may be needed more explicitly if the user wants to use its APIs (checking whether it's sharded or accessing the local shards). We can apply native torch ops to an XLAShardedTensor as if it's a regular torch.Tensor, since it dispatches the wrapped tensor to the backend.

For the first case, y is still a regular torch.Tensor but its IR node is sharding annotated (passed to the mark_sharding API). It's the same programming model as if nothing is sharded, but the computation on it will be partitioned/distributed. The user can use either y or y_shard as if it's on a single device. y_shard will provide access to the XLAShardedTensor APIs.

For the second case,

y = x * 2
z = y + 1
# Case 2: like Case 1 but running `mark_sharding` at the end
xs.mark_sharding(y, (1, 1, partition, partition), (0, 1, 2, 3))

If the user executed (e.g., printed) z before mark_sharding, then it would have been computed without sharding; the executions after mark_sharding will be partitioned.

ronghanghu commented 2 years ago

@yeounoh I see, this is great to hear and should make the SPMD execution much simpler. We probably won't even need a wrapper class at all, and perhaps a few annotation utilities will be enough.

Another question: are the annotations marked by xs.mark_sharding persistent through xm.mark_step, or do we need to mark them again every iteration? Concretely, do we need to call xs.mark_sharding on a parameter such as layer.weight in every training iteration, or only once at the beginning of training?

yeounoh commented 2 years ago

It will persist through xm.mark_step, so the annotation is needed just once, before the execution. Another thing I am going to implement is blocking repeated mark_sharding calls on the same tensor. The invariant is that mark_sharding can only be called on an unannotated tensor; if the user wants to try a different sharding after a previous computation, then clear_sharding needs to be called first.
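
To make the intended invariant concrete, here is a toy model in plain Python. It is not the PT/XLA implementation: `ShardingRegistry` and its methods are hypothetical, and plain Python objects stand in for tensors.

```python
class ShardingRegistry:
    """Toy model of the invariant: annotate once, clear before re-annotating."""

    def __init__(self):
        self._specs = {}  # id(tensor) -> partition spec

    def mark_sharding(self, tensor, partition_spec):
        key = id(tensor)
        if key in self._specs:
            raise RuntimeError("tensor is already annotated; call clear_sharding first")
        self._specs[key] = tuple(partition_spec)

    def clear_sharding(self, tensor):
        self._specs.pop(id(tensor), None)

    def get_sharding(self, tensor):
        return self._specs.get(id(tensor))


registry = ShardingRegistry()
weight = object()  # placeholder for a parameter tensor

registry.mark_sharding(weight, (1, None))   # the one-time annotation
try:
    registry.mark_sharding(weight, (None, 1))
    second_call_blocked = False
except RuntimeError:
    second_call_blocked = True              # re-annotating is rejected

registry.clear_sharding(weight)             # explicit reset ...
registry.mark_sharding(weight, (None, 1))   # ... then a new sharding is allowed
```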

I think we can still define an interface to apply function_rules and provide some example (commonly used) rules. Let's revise the proposal a bit together. I still see that this line of work will make things easier and more concise.

JackCaoG commented 2 years ago

It will persist through xm.mark_step

I am wondering how this is done. I thought sharding_spec is a field in the IR, and upon mark_step we clear all IR and replace it with device_data. Though I guess in this case the device data will be on multiple devices, so it is still sharded.

yeounoh commented 2 years ago

@JackCaoG yeah, it's persisted through the xla_data, which is PjRtShardedData. I am thinking about using the IR as a single source of truth for the sharding annotation, so that resetting the IR will check and preserve the sharding annotation if it exists. Let me follow up with you offline.

yeounoh commented 2 years ago

@ronghanghu I had a discussion with @JackCaoG, and we've decided to block mark_sharding from taking effect on computation that comes before the call. It is always forward-looking and thus appears eager.

y = x * 2
z = y + 1   # execute unpartitioned
# Case 2: like Case 1 but running `mark_sharding` at the end
xs.mark_sharding(y, (1, 1, partition, partition), (0, 1, 2, 3))
z = y + 1  # execute partitioned

I think we should continue working towards GSPMDParallel 😄 The sharding should still be preserved through mark_step.

ronghanghu commented 2 years ago

@yeounoh I see, sounds good and thanks for letting me know!

JackCaoG commented 2 years ago

For the real model, it should look more like

input = torch.tensor(...)
model = nn.Linear(...)
xs.mark_sharding(model.parameters()[0], spec)           # not an expert, not sure if we need to do `.data`
xs.mark_sharding(input, spec)
res = model(input)

ronghanghu commented 2 years ago

@JackCaoG Cool, this usage looks good to me :)

ronghanghu commented 2 years ago

@yeounoh Another relevant usage issue (that I also discussed with Jack today) is how to initialize model parameters under SPMD in a multi-host setting.

In the xs.mark_sharding(model.parameters()[0], spec) example above, suppose we do the vanilla data parallelism where the model parameters should be replicated across the TPU mesh. When we have multiple hosts (e.g. v3-32), each host might initialize its parameter to different values if we follow the typical way of first initializing the parameters on the CPU and then casting them to XLA devices (unless we explicitly enforce the same torch random seed everywhere, which people typically don't do, and is often not possible e.g. because we need different data augmentation on each host). So we might end up marking the model parameters as replicated when they are actually not replicated (same shape but different values).

PyTorch DDP deals with this issue with a broadcast-based synchronization when initializing the DDP wrapper. I'm thinking we can implement a similar sync in our GSPMDParallel wrapper class or a sync function to do a parameter broadcasting in a multi-host SPMD setting, before actually starting the training.
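
The mismatch and the broadcast fix can be simulated in plain Python. The sketch below is only a toy: hosts are modeled as lists of floats, `init_params` and `broadcast_from` are hypothetical helpers, and a real implementation would use a cross-host collective rather than list copies.

```python
import random


def init_params(seed, n=4):
    """Simulate one host initializing its parameters with its own seed."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def broadcast_from(src, per_host_params):
    """Overwrite every host's copy with the values held by host `src`."""
    source = list(per_host_params[src])
    return [list(source) for _ in per_host_params]


# Four hosts, each seeded differently (e.g. to get different data augmentation).
hosts = [init_params(seed=host_id) for host_id in range(4)]
replicated_before = all(params == hosts[0] for params in hosts)

# One-time sync before training, analogous to what DDP does at wrap time.
hosts = broadcast_from(0, hosts)
replicated_after = all(params == hosts[0] for params in hosts)
```

Before the sync, the per-host copies share a shape but not their values, which is exactly the case where marking them as replicated would be wrong.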

yeounoh commented 2 years ago

I see. I was hoping to handle this in the PyTorch/XLA layer, so that the user doesn't need to do anything but treat the multi-host case the same as a single replica, just with more cores. Let's resume the discussion when we start pod testing.

alanwaketan commented 1 year ago

This seems to be very similar to distribute_module in https://github.com/pytorch/pytorch/issues/88838.

ColdCodeCool commented 1 year ago

hi all, is there any update for this RFC? Has it been merged into the code base?

alanwaketan commented 1 year ago

hi all, is there any update for this RFC? Has it been merged into the code base?

No, it has not been merged into the code base yet. We are looking into integrating this API into the distribute_module listed above.