pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Pipeline parallelism with SPMD #6646

Open amithrm opened 4 months ago

amithrm commented 4 months ago

🚀 The feature, motivation and pitch

Motivation

SPMD sharding in PyTorch/XLA provides model parallelism by sharding tensors within an operator. However, we need a mechanism to combine this capability with pipeline parallelism for large models that cannot rely on SPMD sharding (via the mark_sharding APIs) alone, either for performance reasons or because of memory constraints.

Pitch

The high-level idea is to integrate the pipeline-parallel functionality of the existing NeuronxDistributed package with GSPMD: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html As described in the docs, “In NeuronxDistributed, we use Pytorch’s FX to trace the model and do partition on the FX IR. User simply needs to specify where to cut the pipeline stages, and our algorithm will cut the pipeline stages and assign the corresponding modules to each Neuron core automatically.”
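The "trace the model and do partition on the FX IR" step can be illustrated with plain torch.fx. This is a minimal sketch, not NeuronxDistributed's algorithm: it uses `torch.fx.symbolic_trace` and `torch.fx.passes.split_module.split_module` to cut a toy stack of layers into stages at named cut points. `Block`, `Stack`, and `split_at_cuts` are hypothetical names, and the sketch assumes (for illustration only) that a cut named `layers.<i>` ends a stage immediately after that layer.

```python
# Sketch: trace a model with torch.fx and split its graph into pipeline
# stages at named cut points. Not the NeuronxDistributed implementation.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module


class Block(nn.Module):
  # Hypothetical toy layer, shaped like SimpleLinear in the example below.
  def __init__(self, d):
    super().__init__()
    self.fc1 = nn.Linear(d, 4 * d, bias=False)
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(4 * d, d, bias=False)

  def forward(self, x):
    return self.fc2(self.relu(self.fc1(x)))


class Stack(nn.Module):
  def __init__(self, d, n):
    super().__init__()
    self.layers = nn.ModuleList([Block(d) for _ in range(n)])

  def forward(self, x):
    for layer in self.layers:
      x = layer(x)
    return x


def split_at_cuts(model, cuts):
  """Return a GraphModule whose submod_<k> children are the pipeline stages.

  Assumption: each cut names a layer ('layers.<i>') and the stage boundary
  falls immediately after that layer.
  """
  gm = symbolic_trace(model)  # leaf modules become call_module nodes
  cut_indices = sorted(int(c.split('.')[1]) for c in cuts)
  state = {'stage': 0}

  def assign_stage(node):
    # Nodes arrive in graph order; a call into 'layers.<i>.*' sets the stage
    # (number of cuts strictly before layer i); other nodes inherit it.
    if node.op == 'call_module' and node.target.startswith('layers.'):
      idx = int(node.target.split('.')[1])
      state['stage'] = sum(idx > c for c in cut_indices)
    return state['stage']

  return split_module(gm, model, assign_stage)


if __name__ == '__main__':
  torch.manual_seed(0)
  model = Stack(d=8, n=10)
  staged = split_at_cuts(model, ['layers.4'])
  x = torch.randn(3, 8)
  # The staged module should compute the same result as the original.
  print(torch.allclose(staged(x), model(x)))
```

Each `submod_<k>` holds only the modules of its stage, so in a real pipeline runtime each one could be placed on a different device (or device group) and fed microbatches.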

Alternatives

No response

Additional context

No response

amithrm commented 4 months ago

A simple example to get the conversation started, which we can also use as the target for feature completeness:

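```python
# Target example: NxD pipeline parallelism (NxDPPModel) combined with
# GSPMD sharding (xs.mark_sharding) of the weights inside each stage.
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import torch_xla.utils.utils as xu
from torch_xla.distributed.spmd import Mesh
from neuronx_distributed.pipeline import NxDPPModel

# Command-line flags used below (defaults are illustrative).
parser = argparse.ArgumentParser()
parser.add_argument('--input_dim', type=int, default=128)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--train_dataset_len', type=int, default=1024)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--profile', action='store_true')
parser.add_argument('--profiler_port', type=int, default=9012)
FLAGS = parser.parse_args()

# SPMD mode must be enabled before any XLA tensors are created.
xr.use_spmd()

pipeline_cuts = ['layers.4']

class SimpleLinear(nn.Module):

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim * 4, bias=False)
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(FLAGS.input_dim * 4, FLAGS.input_dim, bias=False)

  def forward(self, x):
    y = self.relu(self.fc1(x))
    z = self.fc2(y)
    return z

# Inherit from nn.Module (not SimpleLinear) so the stack does not carry an
# extra, unused fc1/fc2 pair of its own.
class StackedLinear(nn.Module):

  def __init__(self):
    super().__init__()
    self.layers = nn.ModuleList([SimpleLinear() for _ in range(10)])

  def forward(self, x):
    # Call each layer as a module (not .forward) so hooks and FX tracing see it.
    for layer in self.layers:
      x = layer(x)
    return x

device = xm.xla_device()
loss_fn = nn.MSELoss()  # assumed loss; the original snippet left loss_fn undefined

def train():
  lr = 0.1
  train_loader = xu.SampleGenerator(
      data=(torch.randn(FLAGS.batch_size, 2, FLAGS.input_dim),
            torch.randn(FLAGS.batch_size, 2, FLAGS.input_dim)),
      sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)
  torch.manual_seed(42)
  # Keep a handle on the unwrapped module so its layers can be sharded below.
  base_model = StackedLinear().to(device)

  model = NxDPPModel(
      base_model,
      transformer_layer_cls=SimpleLinear,
      num_microbatches=FLAGS.batch_size,
      output_loss_value_spec=(True),
      input_names=['x'],
      pipeline_cuts=pipeline_cuts,
      trace_file_path=None,
      leaf_module_cls=None,
      autowrap_modules=None,
      use_zero1_optimizer=True,
  )

  num_devices = xr.global_runtime_device_count()
  # Define a mesh with all devices along one axis ('y')
  mesh_shape = (1, num_devices)

  device_ids = np.arange(num_devices)
  mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

  for l in base_model.layers:
    # fc1.weight is (4*d, d): shard dim 0 (output features) across 'y'
    xs.mark_sharding(l.fc1.weight, mesh, ('y', 'x'))
    # fc2.weight is (d, 4*d): shard dim 1 (input features) across 'y'
    xs.mark_sharding(l.fc2.weight, mesh, ('x', 'y'))

  optimizer = optim.SGD(model.parameters(), lr=lr)

  # Nested so it can use model, optimizer and device from train()'s scope.
  def train_loop_fn(loader, epoch):
    model.train()
    for step, (data, target) in enumerate(loader):
      with xp.StepTrace('train_linear_model'):
        with xp.Trace('build_graph'):
          data = data.to(device)
          target = target.to(device)
          optimizer.zero_grad()
          output = model(data)
          loss = loss_fn(output, target)
          loss.backward()
        optimizer.step()
      xm.mark_step()
      if step % 10 == 0:
        print(f"Epoch {epoch} step {step} loss {loss}")

  for epoch in range(FLAGS.num_epochs):
    train_loop_fn(train_loader, epoch)

  return model

if __name__ == '__main__':
  if FLAGS.profile:
    server = xp.start_server(FLAGS.profiler_port)

  print('Start training loop...')
  m = train()
  t = torch.randn(10, FLAGS.input_dim).to(device)
  m(t).cpu()
```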
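To make the `mark_sharding` calls in the example concrete: on a mesh of shape (1, N) with axis names ('x', 'y'), the spec ('y', 'x') splits a weight's dim 0 N ways and leaves dim 1 whole, because axis 'x' has size 1. The helper below is a hypothetical plain-Python sketch of that per-device shard-shape arithmetic, not a torch_xla API; it only handles dims that divide evenly and raises otherwise.

```python
# Hypothetical helper: per-device shard shape for a tensor sharded over a
# logical device mesh. One partition-spec entry per tensor dim: a mesh-axis
# name (shard that dim across the axis) or None (replicate that dim).
def shard_shape(shape, mesh_shape, axis_names, partition_spec):
  if len(partition_spec) != len(shape):
    raise ValueError('partition_spec needs one entry per tensor dimension')
  axis_size = dict(zip(axis_names, mesh_shape))
  out = []
  for dim, axis in zip(shape, partition_spec):
    ways = 1 if axis is None else axis_size[axis]
    if dim % ways:
      # This sketch only models evenly divisible sharding.
      raise ValueError(f'dim {dim} is not divisible into {ways} shards')
    out.append(dim // ways)
  return tuple(out)


if __name__ == '__main__':
  d, n = 128, 32
  # fc1.weight (4*d, d) with ('y', 'x'): each device holds 512 / 32 rows.
  print(shard_shape((4 * d, d), (1, n), ('x', 'y'), ('y', 'x')))
  # fc2.weight (d, 4*d) with ('x', 'y'): each device holds 512 / 32 columns.
  print(shard_shape((d, 4 * d), (1, n), ('x', 'y'), ('x', 'y')))
```

With d = 128 and 32 devices this gives (16, 128) for fc1.weight and (128, 16) for fc2.weight, i.e. fc1's output features and fc2's matching input features are split across the same mesh axis.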

amithrm commented 4 months ago

While trying to make this work, I hit a basic issue and filed a ticket for it: https://github.com/pytorch/xla/issues/6647

yeounoh commented 4 months ago

Thanks @amithrm, +1! Looking forward to pipelining with GSPMD!