microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0

Does DeepSpeed's Pipeline-Parallelism optimizer support skip connections? #932

Open RoyMahlab opened 1 month ago

RoyMahlab commented 1 month ago

In your example you convert the AlexNet into a list of layers:

def join_layers(vision_model):
    layers = [
        *vision_model.features,
        vision_model.avgpool,
        lambda x: torch.flatten(x, 1),
        *vision_model.classifier,
    ]
    return layers

which is later passed to PipelineModule:

net = AlexNet(num_classes=10)
net = PipelineModule(layers=join_layers(net),
                     loss_fn=torch.nn.CrossEntropyLoss(),
                     num_stages=args.pipeline_parallel_size,
                     partition_method=part,
                     activation_checkpoint_interval=0)

This seems to bypass the forward method defined in your AlexNet module, since the layers in the list are simply applied one after another. That makes me wonder whether I can have skip connections in my model while using DeepSpeed's Pipeline-Parallelism optimizer.
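To make the concern concrete, here is a minimal sketch in plain Python. It assumes that a flat layer list is executed strictly one entry after another. It makes no DeepSpeed or PyTorch calls, and `run_sequential` and `Residual` are hypothetical names invented for illustration, not part of either library. The idea it shows is that a skip connection can only survive such a list if it lives inside a single list entry.

```python
from typing import Callable, Sequence


def run_sequential(layers: Sequence[Callable], x):
    """Apply the layers one after another (assumed behavior of a flat layer list)."""
    for layer in layers:
        x = layer(x)
    return x


class Residual:
    """Hypothetical wrapper: computes x + body(x), keeping the skip inside one entry."""

    def __init__(self, *body: Callable):
        self.body = body

    def __call__(self, x):
        return x + run_sequential(self.body, x)


# Stand-ins for real layers; any callables whose outputs support `+` would do.
double = lambda x: 2 * x
add_one = lambda x: x + 1

plain = [double, add_one]                # computes add_one(double(x)), no skip
with_skip = [Residual(double, add_one)]  # computes x + add_one(double(x))

plain_out = run_sequential(plain, 3)     # 2 * 3 + 1 = 7
skip_out = run_sequential(with_skip, 3)  # 3 + 7 = 10
```

In a PyTorch model the wrapper would presumably be an `nn.Module` subclass. Because it is a single list entry, everything between the start and end of the skip would have to sit on the same pipeline stage. Whether that is acceptable, or whether the skipped activation can instead be passed between stages, is exactly what I am asking.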

Many thanks!