pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[DTensor] Support API to shard to parent mesh #116101

Open awgu opened 8 months ago

awgu commented 8 months ago

Context

To compose per-parameter-sharding FSDP with DTensor-based tensor parallelism, we need to reshard an existing DTensor to its parent mesh and include the FSDP dim-0 sharding.

The current prototype does this manually. See the logic for deriving the global (i.e. parent) mesh metadata here: https://github.com/pytorch/pytorch/blob/ceb5797909d789866ad67778fe67e3dedc780d11/torch/distributed/_composable/fsdp/_fsdp_param.py#L201

However, this manual logic is not general, as it assumes only up to 2D parallelism. Preferably, this logic would live outside of FSDP.
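
For concreteness, here is a minimal sketch of the 2D setup in question (not the prototype's actual code), assuming a 4-GPU job launched with torchrun; the mesh sizes, dim names, and tensor shape are illustrative only:

```python
# Minimal 2D (FSDP + TP) setup sketch; run under e.g.
#   torchrun --nproc_per_node=4 sketch.py
import os

import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# One GPU per rank under torchrun.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Parent mesh: dim 0 = data parallel (FSDP), dim 1 = tensor parallel.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # 1D TP submesh (the rightmost mesh dim)

# TP shards the parameter over the 1D TP submesh.
torch.manual_seed(0)  # same full tensor on every rank, for simplicity
weight = torch.randn(8, 8, device="cuda")
tp_param = distribute_tensor(weight, tp_mesh, (Shard(0),))

# Goal of this issue: reshard `tp_param` onto the parent mesh `mesh_2d` so that
# the result also carries FSDP's dim-0 sharding (a Shard(0) placement on the
# "dp" mesh dim), without FSDP hand-deriving the parent-mesh metadata.
```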

Proposal

One option is to use the DTensor constructor directly and run this resharding logic when the passed-in local_tensor is already a DTensor. This path would check that the local_tensor: DTensor and the passed-in mesh share a parent mesh and would perform the appropriate resharding.

cc @wanchaol @wz337 @XilunWu @tianyu-l

wanchaol commented 8 months ago

This could be done inside the distribute_tensor API, i.e. the following should work:

  1. distribute_tensor used to accept only a plain tensor and produce a DTensor; now it should also accept a DTensor in some cases.
  2. The specific case is when it gets a DTensor and we find that the device_mesh of this DTensor belongs to a parent mesh; then it is legitimate to produce a DTensor that lives on the parent mesh (see the sketch after this list).
  3. Note that this is only legitimate if the submesh DTensor lives on the rightmost device mesh dimension when the parent mesh placement is Shard.
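
A hedged sketch of the call this would enable, continuing the setup from the Context sketch above (distribute_tensor accepting a DTensor is the proposal under discussion, not existing behavior; all names carry over from that sketch):

```python
# Proposed usage (hypothetical until this issue is implemented): pass the
# existing TP DTensor plus the 1D FSDP submesh; distribute_tensor would detect
# that both submeshes hang off the same parent mesh and reshard accordingly.
dp_mesh = mesh_2d["dp"]  # 1D FSDP submesh (the leftmost mesh dim)
param_2d = distribute_tensor(tp_param, dp_mesh, (Shard(0),))

# Expected result per this discussion: a DTensor living on the 2D parent mesh
# `mesh_2d`, combining FSDP's Shard(0) on "dp" with the original TP sharding.
```
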
awgu commented 8 months ago

> Note that this is only legitimate if the submesh DTensor lives on the rightmost device mesh dimension when the parent mesh placement is Shard.

To check, is this satisfied for 2D (TP + FSDP)? (I think so since the mesh has shape (dp_size, tp_size), so the TP DTensor is on the rightmost mesh dim.)

Also, to check, for the 2D case where tensor is already a TP DTensor, would we pass a 1D FSDP mesh as device_mesh and (Shard(0),) as placements? (i.e., the caller does not need to derive the parent mesh and the parent placements?)

```python
def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
```
wanchaol commented 8 months ago

> Note that this is only legitimate if the submesh DTensor lives on the rightmost device mesh dimension when the parent mesh placement is Shard.
>
> To check, is this satisfied for 2D (TP + FSDP)? (I think so since the mesh has shape (dp_size, tp_size), so the TP DTensor is on the rightmost mesh dim.)

Yep, for 2D it makes sense because TP shards first and FSDP shards second, so the sharding happens from the innermost mesh dimension outward, and the global DTensor sharding layout we create is still valid. For other cases, like sharding on mesh dim 0 first and then on mesh dim 1, we should error out since that is not a valid global sharding layout.
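
For contrast, here is a sketch of the direction that should be rejected, again continuing the earlier setup (the error behavior is the intent described above, not an existing check):

```python
# Sharding on the outer "dp" mesh dim first and then trying to further shard
# the result across the inner "tp" dim would not correspond to a valid global
# sharding layout, per the discussion above.
dp_first = distribute_tensor(weight, mesh_2d["dp"], (Shard(0),))

# Under the proposed extension, a call like the following should error out:
# distribute_tensor(dp_first, mesh_2d["tp"], (Shard(0),))
```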

> Also, to check, for the 2D case where tensor is already a TP DTensor, would we pass a 1D FSDP mesh as device_mesh and (Shard(0),) as placements? (i.e., the caller does not need to derive the parent mesh and the parent placements?)
>
> ```python
> def distribute_tensor(
>     tensor: torch.Tensor,
>     device_mesh: Optional[DeviceMesh] = None,
>     placements: Optional[Sequence[Placement]] = None,
> ) -> DTensor:
> ```

Yes, exactly! For the 2D case where we already have a TP DTensor, we can just pass the 1D FSDP mesh as device_mesh and (Shard(0),) as placements. distribute_tensor has the logic to figure out whether this "further sharding" makes sense or not, and in the TP + FSDP case it does.