Open amithrm opened 4 months ago
A simple example to get the conversation started, which we can use to drive the feature to completion.
`
# Module names at which the pipeline is split into stages; handed to
# NxDPPModel via its ``pipeline_cuts`` argument in train().
pipeline_cuts = ['layers.4']
class SimpleLinear(nn.Module):
    """Bias-free two-layer block: widen 4x, apply ReLU, project back.

    Layer sizes are taken from ``FLAGS.input_dim``.
    """

    def __init__(self):
        super().__init__()
        width = FLAGS.input_dim
        hidden = width * 4
        # fc1 is created before fc2 so parameter initialisation order is kept.
        self.fc1 = nn.Linear(width, hidden, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        return self.fc2(self.relu(self.fc1(x)))
class StackedLinear(SimpleLinear):
    """Ten ``SimpleLinear`` blocks applied one after another.

    NOTE(review): this subclasses ``SimpleLinear``, so the parent's own
    ``fc1``/``relu``/``fc2`` are created as well but are never used by
    ``forward`` below -- confirm whether ``nn.Module`` was the intended base.
    The base class is left unchanged here to keep the existing interface.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([SimpleLinear() for _ in range(10)])

    def forward(self, x):
        """Feed ``x`` through every block in ``self.layers``, in order."""
        for layer in self.layers:
            # Call the module itself instead of ``layer.forward(x)``: going
            # through ``nn.Module.__call__`` runs registered hooks and is the
            # entry point that module-level tracing intercepts, so each block
            # stays visible as its own submodule call (e.g. ``layers.4``).
            x = layer(x)
        return x
# XLA device shared by this example: the model, each training batch and the
# final test tensor are all moved to it.
device = xm.xla_device()
def train():
    """Wrap the stacked model for pipeline parallelism, shard it, and train.

    Builds a synthetic data loader, wraps ``StackedLinear`` in ``NxDPPModel``
    using the module-level ``pipeline_cuts``, marks the ``fc1``/``fc2``
    weights of every block for sharding over a ``(1, NUM_DEVICES)`` mesh,
    and runs SGD for ``FLAGS.num_epochs`` epochs.

    Returns:
        The ``NxDPPModel``-wrapped model after the training loop has run.
    """
    lr = 0.1

    train_loader = xu.SampleGenerator(
        data=(torch.randn(FLAGS.batch_size, 2, FLAGS.input_dim),
              torch.randn(FLAGS.batch_size, 2, FLAGS.input_dim)),
        sample_count=FLAGS.train_dataset_len // FLAGS.batch_size)

    # Seeded after the synthetic data is drawn, so only the model
    # initialisation below is made reproducible by this call.
    torch.manual_seed(42)

    model = StackedLinear().to(device)
    model = NxDPPModel(
        model,
        transformer_layer_cls=SimpleLinear,
        num_microbatches=FLAGS.batch_size,
        output_loss_value_spec=True,
        input_names=['x'],
        pipeline_cuts=pipeline_cuts,
        trace_file_path=None,
        leaf_module_cls=None,
        autowrap_modules=None,
        use_zero1_optimizer=True,
    )

    num_devices = NUM_DEVICES
    # Define a mesh with all devices along one axis. The second mesh dimension
    # is derived from num_devices so it always matches len(device_ids).
    mesh_shape = (1, num_devices)
    device_ids = np.arange(num_devices)
    mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

    # NOTE(review): ``model`` is the NxDPPModel wrapper at this point; this
    # assumes the wrapper still exposes ``.layers`` -- confirm.
    for layer in model.layers:
        # Shard the first layer's weights row-wise (dim 0 over mesh axis 'y')
        xs.mark_sharding(layer.fc1.weight, mesh, ('y', 'x'))
        # Shard the second layer's weights column-wise (dim 1 over mesh axis 'y')
        xs.mark_sharding(layer.fc2.weight, mesh, ('x', 'y'))

    optimizer = optim.SGD(model.parameters(), lr=lr)

    def train_loop_fn(loader, epoch):
        """Run one pass over ``loader``, printing the loss every 10 steps."""
        model.train()
        for step, (data, target) in enumerate(loader):
            with xp.StepTrace('train_linear_model'):
                with xp.Trace('build_graph'):
                    data = data.to(device)
                    target = target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    # NOTE(review): ``loss_fn`` is not defined anywhere in
                    # this snippet; it has to be provided at module scope.
                    loss = loss_fn(output, target)
                    loss.backward()
                    optimizer.step()
                # NOTE(review): nesting of mark_step/logging relative to the
                # trace scopes was reconstructed from a paste that lost its
                # indentation -- confirm against the original script.
                xm.mark_step()
                if step % 10 == 0:
                    print(f"Epoch {epoch} step {step} loss {loss}")

    for epoch in range(FLAGS.num_epochs):
        train_loop_fn(train_loader, epoch)
    return model
# Start the profiler server only when profiling was requested; the handle is
# kept in a module-level name so it stays referenced for the rest of the run.
if FLAGS.profile:
    server = xp.start_server(FLAGS.profiler_port)

print('Start training loop...')
m = train()

# Run the returned model once on a fresh random batch and copy the output
# back to the CPU.
t = torch.randn(10, FLAGS.input_dim)
t = t.to(device)
m(t).cpu()
`
While trying to make this work, I ran into a basic issue and created a ticket for it: https://github.com/pytorch/xla/issues/6647
Thanks @amithrm +1 looking forward to pipelining using GSPMD!
🚀 The feature, motivation and pitch
Motivation
SPMD sharding in PyTorch/XLA offers model parallelism by sharding tensors within an operator. However, we need a mechanism to integrate this capability with pipeline parallelism for models that are large and cannot use SPMD sharding (via the mark_sharding APIs), either for performance reasons or because of memory constraints.
Pitch
The high-level idea is to integrate the pipeline-parallel functionality of the existing NeuronxDistributed package with GSPMD; see https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html. As described in the docs, “In NeuronxDistributed, we use Pytorch’s FX to trace the model and do partition on the FX IR. User simply needs to specify where to cut the pipeline stages, and our algorithm will cut the pipeline stages and assign the corresponding modules to each Neuron core automatically.”
Alternatives
No response
Additional context
No response