
[spmd] incorrect aten.expand call with nn.linear (expanded size must match existing size at dim 0) #632

Open lessw2020 opened 1 year ago

lessw2020 commented 1 year ago

This issue tracks/investigates the problem reported by Rich Zhu: using permute to generate a transposed tensor for nn.linear results in an incorrect aten.expand call.

I've found two potential issues:

a - The original DTensor placement spec is inconsistent with the device mesh: the placement is [Shard(dim=0)] even though the device mesh covers 2 GPUs and has shape [1, 2]. (Note: this was not detected when Rich ran it, though his run was inside a testing environment.)

b - After correcting this by updating the placement to [Shard(dim=0), Shard(dim=0)] in the view-ops reshape propagation so it matches the mesh, target_schema.args_schema[1:] in dispatch.py then requests an expansion of a [2, 40, 10] tensor into [1, 40, 10], which fails when the op is actually called: RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0. Target sizes: [1, 40, 10]. Tensor sizes: [2, 40, 10]

The permute operation and the replication of the weight tensor both appear to complete without issue.

Repro steps, using a simple model:

import torch
from copy import deepcopy

# Import paths below are assumed from the module paths shown in the logs/traceback
# (spmd.compiler.api for SPMD/Schema, spmd.tensor for the DTensor types).
from spmd.compiler.api import SPMD, Schema
from spmd.tensor import DeviceMesh, Replicate


class Permute(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(5)
        self.w = torch.nn.Parameter(torch.rand((5, 10)))
        self.b = torch.nn.Parameter(torch.rand((5)))

    def forward(self, x):
        # permute (2, 10, 40) -> (2, 40, 10) so the last dim matches w's in_features
        x_t = x.permute(0, 2, 1)
        return torch.nn.functional.linear(x_t, self.w, self.b)


# rank, _device_type, gpu_placement, and world_size come from the surrounding
# distributed test harness and are not shown here.
model = Permute().to(rank)
spmd = SPMD(
    deepcopy(model),
    schema=Schema(
        mesh=DeviceMesh(
            _device_type, gpu_placement  # torch.arange(world_size)
        ),
        placements=[Replicate()],
    ),
)

x = torch.randn(2, 10, 40).to(rank)

spmd(x).sum().backward()
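
For reference, a quick shape check outside of SPMD (my own sketch, reusing the shapes from the model above) confirms what the permute plus linear should produce on a single device:

import torch

x = torch.randn(2, 10, 40)
x_t = x.permute(0, 2, 1)
print(x_t.shape)  # torch.Size([2, 40, 10]); last dim matches the (5, 10) weight's in_features
out = torch.nn.functional.linear(x_t, torch.rand(5, 10), torch.rand(5))
print(out.shape)  # torch.Size([2, 40, 5])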

1 - Running the above on 2 GPUs results in the following error:

view ops 506: 
 in_shard=[Shard(dim=0)], mesh_sizes=(1, 2), local_in_shape=(2, 10, 40)

Traceback (most recent call last):
  File "/home/ubuntu/feature_fusion/spmd/tensor/dispatch.py", line 222, in propagate_input_sharding
    output_sharding = sharding_prop_func(op_schema)
  File "/home/ubuntu/feature_fusion/spmd/tensor/ops/view_ops.py", line 665, in reshape_prop
    ) = propagate_shape_and_sharding(
  File "/home/ubuntu/feature_fusion/spmd/tensor/ops/view_ops.py", line 508, in propagate_shape_and_sharding
    assert len(in_shard) == len(mesh_sizes)
AssertionError
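
A minimal, self-contained illustration of why this assert trips (Shard below is just a stand-in namedtuple, not the real spmd placement type; the values are taken from the log line above):

from collections import namedtuple

Shard = namedtuple("Shard", ["dim"])  # stand-in for spmd's Shard placement

in_shard = [Shard(dim=0)]  # a single placement entry on the input DTensor spec...
mesh_sizes = (1, 2)        # ...but the device mesh has two dimensions ([1, 2])
assert len(in_shard) == len(mesh_sizes)  # 1 != 2 -> AssertionError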

To work around this, I added a patch that upgrades the placements sequence to:

input dtensor spec [Shard(dim=0), Shard(dim=0)]

2 - With the placement updated (possibly incorrectly, since the root issue may be exactly why the proper placement isn't being produced in the first place), we move on to the expansion issue. The weights are replicated, and the input tensor is properly permuted from (2, 10, 40) to (2, 40, 10).

We then get to the expansion issue:

INFO:spmd.compiler.api:node5: op=call_function target=aten.expand.default

 dispatch.py 262, args = torch.Size([2, 40, 10])
([2, 40, 10],)

dispatch.py, 267: op_call=<OpOverload(op='aten.expand', overload='default')>
args[1:]=([2, 40, 10],)

and then:

==> dispatch 182:  
args_schema=(DTensorSpec(mesh=DeviceMesh:([[0, 1]]), placements=[Shard(dim=0), Shard(dim=0)], shape=torch.Size([2, 40, 10]), ndim=3), [2, 40, 10])
,op_schema.args = op_schema.args_schema=(DTensorSpec(mesh=DeviceMesh:([[0, 1]]), placements=[Shard(dim=0), Shard(dim=0)], shape=torch.Size([2, 40, 10]), ndim=3), [2, 40, 10])

,func_schema=FunctionSchema(name=OperatorName(name=BaseOperatorName(base='aten::expand', inplace=False, dunder_method=False, functional_overload=False), overload_name=''), arguments=Arguments(pre_self_positional=(), self_arg=SelfArgument(argument=Argument(name='self', type=BaseType(name=<BaseTy.Tensor: 3>), default=None, annotation=Annotation(alias_set=('a',), is_write=False, alias_set_after=()))), post_self_positional=(Argument(name='size', type=ListType(elem=BaseType(name=<BaseTy.SymInt: 17>), size=None), default=None, annotation=None),), pre_tensor_options_kwarg_only=(Argument(name='implicit', type=BaseType(name=<BaseTy.bool: 9>), default='False', annotation=None),), tensor_options=None, post_tensor_options_kwarg_only=(), out=()), returns=(Return(name=None, type=BaseType(name=<BaseTy.Tensor: 3>), annotation=Annotation(alias_set=('a',), is_write=False, alias_set_after=())),))

disp 197: ===> sharding prop func = <function register_prop_rule_map.<locals>.reshape_prop at 0x7efbdecff700> from op_key='aten.expand.default'

 ----> op_schema=OpSchema(func_schema=aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([[0, 1]]), placements=[Shard(dim=0), Shard(dim=0)], shape=torch.Size([2, 40, 10]), ndim=3), [2, 40, 10]), kwargs_schema={})

in dispatch.py:

target_schema, redistribute, output_sharding = propagate_input_sharding(
        op_call, args, kwargs, op_to_rules
    )

Ultimately, the expansion to half the batch size comes from dispatch.py, line 222:

output_sharding = sharding_prop_func(op_schema)

since the op_schema changes silently from:

dispatch 221: ----> op_schema=OpSchema(func_schema=aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([[0, 1]]), placements=[Shard(dim=0), Shard(dim=0)], shape=torch.Size([2, 40, 10]), ndim=3), [2, 40, 10]), kwargs_schema={})

to:

disp 225 sharding prop func after... op_schema=OpSchema(func_schema=aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([[0, 1]]), placements=[Shard(dim=0), Shard(dim=0)], shape=torch.Size([2, 40, 10]), ndim=3), (1, 40, 10)), kwargs_schema={})

and then aten.expand is called with the mutated size (1, 40, 10) against the (2, 40, 10) local tensor, resulting in the error.
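
For context, the aten-level failure can be reproduced in isolation with plain PyTorch (my own sketch, using the sizes from the logs), since expand cannot shrink a non-singleton dimension:

import torch

t = torch.randn(2, 40, 10)  # local tensor size from the logs
t.expand(1, 40, 10)         # RuntimeError: The expanded size of the tensor (1) must match
                            # the existing size (2) at non-singleton dimension 0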

This leaves several questions: 1 - why is the placement (at least in my run) incorrect in the first place? 2 - once the placement is patched, why does the sharding propagation rewrite the expand size so that it tries to shrink dim 0 from 2 to 1, which causes the error?

aazzolini commented 1 year ago

Thanks @lessw2020 for the investigation. I looked into it further, and it turns out that the issue is:

It's confusing and I will try to explain it better tomorrow, but I'm devising a hack with a fix.