🐛 Describe the bug
I have noticed test_linear_row_wise_parallel fails when run with 6 GPUs:
https://github.com/pytorch/pytorch/blob/f2f7ef9d5908f53a5fd7991ed0a1ef99069813a1/test/distributed/tensor/parallel/test_parallelize_api.py#L137
It raises 2 kinds of errors (the first from 5 processes, the last from 1 process):

RuntimeError: a and b must have same reduction dim, but got [9, 18] X [16, 10].
RuntimeError: a and b must have same reduction dim, but got [9, 6] X [16, 10].
The issue is the sharding of the input tensor of inp_size = [9, 16]. It results (correctly) in a sharding of [3, 3, 3, 3, 3, 1], but at some point the code assumes even sharding, leading to 3 * 6 = 18 and 1 * 6 = 6 as the input size to the linear layer. This can be reproduced with this code:
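(The exact snippet from the report is not reproduced here. The following is a minimal sketch of an equivalent repro using the public tensor-parallel API (parallelize_module, RowwiseParallel, init_device_mesh) and a hypothetical repro.py launched with torchrun; it is an approximation intended to hit the same uneven-sharding path, not the test's own code.)

```python
# Approximate repro (not the original snippet from this report): a rowwise-parallel
# nn.Linear(16, 10) on a 6-GPU mesh, fed a locally chunked input whose last dim (16)
# does not divide evenly by 6. Launch with: torchrun --nproc_per_node=6 repro.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])  # expected to be 6
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (world_size,))

    model = nn.Linear(16, 10, device="cuda")
    # RowwiseParallel shards the weight along in_features (16); with 6 ranks the
    # shards are uneven: [3, 3, 3, 3, 3, 1].
    tp_model = parallelize_module(model, mesh, RowwiseParallel())

    inp = torch.rand(9, 16, device="cuda")
    # Chunk the input along the last dim the same way: 5 ranks get a [9, 3] shard,
    # one rank gets [9, 1].
    local_inp = inp.chunk(world_size, dim=-1)[rank]
    # RowwiseParallel's default input layout is Shard(-1); reconstructing the global
    # size as local_size * world_size assumes even shards, giving 3 * 6 = 18 or
    # 1 * 6 = 6 instead of 16, which is what produces the mismatched reduction dim
    # errors quoted above.
    out = tp_model(local_inp)
    print(f"rank {rank}: {tuple(out.shape)}")


if __name__ == "__main__":
    main()
```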
Versions
Note that this could "just" be a design issue of the test: it uses ALL GPUs if there is an even number of GPUs, else only 4:
https://github.com/pytorch/pytorch/blob/f2f7ef9d5908f53a5fd7991ed0a1ef99069813a1/test/distributed/tensor/parallel/test_parallelize_api.py#L35
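A paraphrase of that selection logic (not the exact code at the linked line) looks like this:

```python
import torch

# Sketch of the world-size selection described above (NOT the exact code at the
# linked line): use all GPUs when their count is even, otherwise fall back to 4.
def pick_world_size() -> int:
    gpu_num = torch.cuda.device_count()
    return gpu_num if gpu_num > 0 and gpu_num % 2 == 0 else 4
```

With 6 GPUs this yields a world size of 6, which is exactly the failing configuration.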
From the relevant sizes (16 and 12) it is clear that the first test works only for 2, 4, 8, 16 GPUs and the second only for 2, 3, 4, 6, 12 GPUs, so both together work only with 2 or 4 GPUs.
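This is just divisibility arithmetic; a quick standalone check (16 and 12 being the dimensions referenced above):

```python
# A world size only works for a test if it divides the relevant dimension evenly.
def even_world_sizes(dim, max_gpus=16):
    return [w for w in range(2, max_gpus + 1) if dim % w == 0]

print(even_world_sizes(16))  # [2, 4, 8, 16]
print(even_world_sizes(12))  # [2, 3, 4, 6, 12]
print(sorted(set(even_world_sizes(16)) & set(even_world_sizes(12))))  # [2, 4]
```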
--> The test should run ONLY with a world_size of 4, or the implementation should be fixed if uneven sharding is intended to work.
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o