pytorch / PiPPy

Pipeline Parallelism for PyTorch
BSD 3-Clause "New" or "Revised" License

Share DeviceMesh between PiPPy and SPMD #283

Open jamesr66a opened 2 years ago

jamesr66a commented 2 years ago

I see there is a DeviceMesh abstraction in spmd: https://github.com/pytorch/PiPPy/blob/main/spmd/tensor/device_mesh.py

Can we use this abstraction as shared infrastructure? For example, PipelineDriver.init_data_parallel (https://github.com/pytorch/PiPPy/blob/877eb8c675dd0e34731961c043f8ae2cc1e49a77/pippy/PipelineDriver.py#L461) re-implements many of these concepts by hand; could it use DeviceMesh instead?

cc @wanchaol @aazzolini @pbelevich @kwen2501

aazzolini commented 2 years ago

Would it be possible to separate the concept of pipeline driver from the concept of mesh such that we can have a pipeline driver that doesn't depend on it?

aazzolini commented 2 years ago

It would be interesting to see a diagram of how we would like to structure this. I believe we mentioned that PiPPy can exist separately from SPMD, right? But there is a layer where we need to bring the two together. Is PipelineRunner the layer where both come in?

wanchaol commented 2 years ago

Correct me if I'm wrong: I assume we want PiPPy to use DeviceMesh only for collective communication within stages, because cross-stage communication uses RPC? There are also the n_stages, dp_group_size, and dp_pg_cb=None arguments; we need to understand better why they are there and how to construct a DeviceMesh based on them. Yes, it looks like a diagram would be helpful.
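As a hedged sketch of one way to read those arguments: if the world consists of `n_stages` pipeline stages, each replicated `dp_group_size` times, the ranks form a 2-D grid of shape `(n_stages, dp_group_size)`. The stage-major rank ordering below is an assumption and not necessarily how `PipelineDriver` assigns ranks; `build_mesh`, `dp_groups`, and `pipeline_groups` are hypothetical helpers, not part of PiPPy or the spmd `DeviceMesh` API, and `dp_pg_cb` is not modeled.

```python
from typing import List


def build_mesh(world_size: int, n_stages: int, dp_group_size: int) -> List[List[int]]:
    """Hypothetical helper: arrange ranks 0..world_size-1 into an
    (n_stages, dp_group_size) grid using an ASSUMED stage-major layout."""
    if n_stages * dp_group_size != world_size:
        raise ValueError(
            f"n_stages * dp_group_size = {n_stages * dp_group_size}, "
            f"but world_size = {world_size}"
        )
    # Row s holds the dp_group_size ranks that all execute pipeline stage s.
    return [
        list(range(s * dp_group_size, (s + 1) * dp_group_size))
        for s in range(n_stages)
    ]


def dp_groups(mesh: List[List[int]]) -> List[List[int]]:
    # Each row is one data-parallel group: replicas of the same stage that
    # would all-reduce gradients with each other.
    return [list(row) for row in mesh]


def pipeline_groups(mesh: List[List[int]]) -> List[List[int]]:
    # Each column is one full pipeline: one rank per stage, in stage order.
    return [list(col) for col in zip(*mesh)]


if __name__ == "__main__":
    mesh = build_mesh(world_size=8, n_stages=4, dp_group_size=2)
    print(dp_groups(mesh))        # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(pipeline_groups(mesh))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

Under this assumed layout, one mesh dimension yields the DP groups and the other yields the pipelines, which is roughly the bookkeeping the issue suggests `init_data_parallel` does by hand.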

jamesr66a commented 2 years ago

Quoting @aazzolini: "Would it be possible to separate the concept of pipeline driver from the concept of mesh such that we can have a pipeline driver that doesn't depend on it?"

No

Quoting @aazzolini: "It would be interesting to see a diagram of how we would like to structure this. I believe we mentioned that PiPPy can exist separately from SPMD, right? But there is a layer where we need to bring the two together. Is PipelineRunner the layer where both come in?"

From the PiPPy README:

[Image: PiPPy README diagram]

PP (pipeline parallelism), TP (tensor parallelism), and DP (data parallelism) compose by each handling a different dimension of the device grid. For example, there is a pipeline dimension across which the MPMD stages are laid out, and all of the other dimensions of the device grid are handled by the SPMD implementations within each stage.
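To illustrate "each parallelism owns one dimension of the device grid", here is a minimal sketch in plain Python, assuming a row-major 3-D grid with dimensions named `("pp", "dp", "tp")`. It computes, for each dimension, the rank groups that differ only along that dimension; conceptually this mirrors a device mesh deriving one communication group per mesh dimension. `groups_along_dim` and `all_dim_groups` are hypothetical helpers, not DeviceMesh methods.

```python
import itertools
from typing import Dict, List, Tuple


def groups_along_dim(shape: Tuple[int, ...], dim: int) -> List[List[int]]:
    """Return the rank groups that vary only along `dim` of a row-major grid.

    Every rank belongs to exactly one group per dimension; a collective for
    that dimension (e.g. a DP gradient all-reduce) would run inside that group.
    """
    # Row-major strides: rank = sum(coord[i] * strides[i]).
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]

    other_dims = [d for d in range(len(shape)) if d != dim]
    groups = []
    # Fix the coordinates on every other dimension, then sweep `dim`.
    for fixed in itertools.product(*(range(shape[d]) for d in other_dims)):
        base = sum(c * strides[d] for c, d in zip(fixed, other_dims))
        groups.append([base + k * strides[dim] for k in range(shape[dim])])
    return groups


def all_dim_groups(
    shape: Tuple[int, ...], names: Tuple[str, ...]
) -> Dict[str, List[List[int]]]:
    # One list of groups per named dimension of the grid.
    return {name: groups_along_dim(shape, d) for d, name in enumerate(names)}


if __name__ == "__main__":
    # 2 pipeline stages x 2 DP replicas x 2 TP shards = 8 ranks (assumed layout).
    for name, gs in all_dim_groups((2, 2, 2), ("pp", "dp", "tp")).items():
        print(name, gs)
    # pp [[0, 4], [1, 5], [2, 6], [3, 7]]
    # dp [[0, 2], [1, 3], [4, 6], [5, 7]]
    # tp [[0, 1], [2, 3], [4, 5], [6, 7]]
```

In this picture the "pp" groups are the MPMD pipelines, while the "dp" and "tp" groups are where the per-stage SPMD collectives would run.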

Quoting @wanchaol: "Correct me if I'm wrong: I assume we want PiPPy to use DeviceMesh only for collective communication within stages, because cross-stage communication uses RPC?"

No. Cross-stage is itself a dimension of the device mesh, in the same way that SPMD ranks reside along the mesh's other dimensions.
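A hedged sketch of what "cross-stage is a dimension of the mesh" implies, assuming a row-major grid whose dimension `pp_dim` is the pipeline axis: a rank's previous- and next-stage peers are found by stepping one index along that axis while keeping all other coordinates fixed. The helpers `unravel`, `ravel`, and `stage_neighbors` are hypothetical, and the transport actually used between stages (RPC or otherwise) is not modeled.

```python
from typing import Optional, Tuple


def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Convert a flat rank into row-major grid coordinates.
    coords = []
    for size in reversed(shape):
        coords.append(rank % size)
        rank //= size
    return tuple(reversed(coords))


def ravel(coords: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
    # Inverse of unravel: row-major coordinates back to a flat rank.
    rank = 0
    for c, size in zip(coords, shape):
        rank = rank * size + c
    return rank


def stage_neighbors(
    rank: int, shape: Tuple[int, ...], pp_dim: int = 0
) -> Tuple[Optional[int], Optional[int]]:
    """Return (prev_stage_rank, next_stage_rank) for `rank`, stepping only
    along the pipeline dimension; None at the first/last stage."""
    coords = list(unravel(rank, shape))
    stage = coords[pp_dim]
    prev_rank = next_rank = None
    if stage > 0:
        coords[pp_dim] = stage - 1
        prev_rank = ravel(tuple(coords), shape)
    if stage < shape[pp_dim] - 1:
        coords[pp_dim] = stage + 1
        next_rank = ravel(tuple(coords), shape)
    return prev_rank, next_rank


if __name__ == "__main__":
    # 4 stages x 2 DP replicas (assumed layout): rank 3 sits at stage 1, replica 1.
    print(stage_neighbors(3, (4, 2)))  # (1, 5)
```

Under this reading, cross-stage traffic is point-to-point communication along one mesh axis, just as SPMD collectives run along the other axes.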

aazzolini commented 2 years ago

@jamesr66a I believe there is confusion about whether each pipeline stage would have its own DeviceMesh or the pipeline would be one dimension of a single mesh. The latter is more restrictive; have we decided that we're okay with that limitation?
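One likely reading of the restriction, as a hedged sketch: a single mesh with a pipeline axis must be rectangular, so every stage gets the same number of ranks, while independent per-stage meshes could, for example, give a heavy stage 4 devices and a light stage 2. `as_single_mesh` is a hypothetical helper written for illustration; it checks only this shape constraint and ignores other ways stages might differ.

```python
from typing import List, Optional, Tuple


def as_single_mesh(stage_devices: List[List[int]]) -> Optional[Tuple[int, int]]:
    """Return the (n_stages, ranks_per_stage) shape if the per-stage rank lists
    can be stacked into one rectangular mesh with a pipeline dimension,
    else None.

    A global mesh requires every stage to use the same number of ranks;
    independent per-stage meshes have no such requirement.
    """
    if not stage_devices:
        return None
    sizes = {len(devs) for devs in stage_devices}
    if len(sizes) != 1:
        # Stages differ in size: only per-stage meshes can express this.
        return None
    return len(stage_devices), sizes.pop()


if __name__ == "__main__":
    print(as_single_mesh([[0, 1], [2, 3]]))        # (2, 2)
    print(as_single_mesh([[0, 1, 2, 3], [4, 5]]))  # None
```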