pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Rounding issue of tensor sizes when sharding #141335

Open Flamefire opened 4 days ago

Flamefire commented 4 days ago

🐛 Describe the bug

I have noticed test_linear_row_wise_parallel fails when run with 6 GPUs

https://github.com/pytorch/pytorch/blob/f2f7ef9d5908f53a5fd7991ed0a1ef99069813a1/test/distributed/tensor/parallel/test_parallelize_api.py#L137

It raises two kinds of errors (the first from 5 processes, the second from the remaining process):

RuntimeError: a and b must have same reduction dim, but got [9, 18] X [16, 10].
RuntimeError: a and b must have same reduction dim, but got [9, 6] X [16, 10].

The cause is the sharding of the input tensor, which has inp_size = [9, 16].

Splitting the last dimension (16) across 6 ranks (correctly) yields shards of size [3, 3, 3, 3, 3, 1], but at some point the code assumes even sharding, so the linear layer sees an input size of 3 * 6 = 18 on five ranks and 1 * 6 = 6 on the last one.
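The arithmetic above can be checked without any distributed setup. The following is a minimal sketch in plain Python; `chunk_sizes` is a hypothetical helper (not a PyTorch function) that assumes the split uses chunks of size ceil(dim_size / num_chunks), which is how the [3, 3, 3, 3, 3, 1] split arises. The "assumed global" size is simply each local size multiplied by the world size, as described in the issue.

```python
import math

def chunk_sizes(dim_size: int, num_chunks: int) -> list[int]:
    """Hypothetical helper: split dim_size into pieces of ceil(dim_size / num_chunks)."""
    step = math.ceil(dim_size / num_chunks)
    sizes = []
    remaining = dim_size
    while remaining > 0:
        size = min(step, remaining)  # the last piece may be smaller
        sizes.append(size)
        remaining -= size
    return sizes

WORLD_SIZE = 6

# Local shard sizes of the last input dimension (16) on each rank.
local = chunk_sizes(16, WORLD_SIZE)

# Global size each rank would infer if it assumed all shards were equal.
assumed_global = [size * WORLD_SIZE for size in local]

print(local)           # [3, 3, 3, 3, 3, 1]
print(assumed_global)  # [18, 18, 18, 18, 18, 6]
```

Neither 18 nor 6 matches the weight's reduction dimension of 16, which is consistent with the two error messages quoted above.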

This can be reproduced with this code:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel.api import parallelize_module
from torch.distributed.tensor.parallel.style import RowwiseParallel
import tempfile

WORLD_SIZE = 6

def run(rank, file_name):
    dist.init_process_group(
        backend="gloo",
        world_size=WORLD_SIZE,
        rank=rank,
        init_method=f"file://{file_name}",
    )

    inp_size = [9, 16]  # last dim 16 is not divisible by WORLD_SIZE = 6

    model = torch.nn.Linear(inp_size[-1], 10)
    device_mesh = DeviceMesh("cpu", list(range(WORLD_SIZE)))
    model = parallelize_module(model, device_mesh, RowwiseParallel())

    inp = torch.rand(*inp_size)
    inp = inp.chunk(WORLD_SIZE, dim=-1)[rank]
    model(inp)

    dist.barrier()
    dist.destroy_process_group()

if __name__ == '__main__':
    file_name = tempfile.NamedTemporaryFile(delete=False).name
    torch.multiprocessing.spawn(run, args=(file_name,), nprocs=WORLD_SIZE, join=True, daemon=False)

Versions

PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Linux Mint 21.3 (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)

Note that this could "just" be a design issue of the test:

It uses ALL GPUs if there is an even number of GPUs, else only 4: https://github.com/pytorch/pytorch/blob/f2f7ef9d5908f53a5fd7991ed0a1ef99069813a1/test/distributed/tensor/parallel/test_parallelize_api.py#L35

From the sizes it is clear that the first test works only with 2, 4, 8, or 16 GPUs and the second only with 2, 3, 4, 6, or 12 GPUs, so both together work only with 2 or 4 GPUs.

--> Either the test should run ONLY with a world_size of 4, or the implementation should be fixed if other world sizes are intended to work.
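The set of usable world sizes can be derived mechanically. This is a rough sketch in plain Python: the sharded dimension sizes 16 and 12 are assumptions inferred from the GPU counts listed above (they are the sizes whose divisors match those lists), and `selected_world_size` only paraphrases the rule as described in this issue, not the test's actual code.

```python
def even_world_sizes(dim_size: int) -> set[int]:
    """World sizes (>= 2) that split dim_size into equal shards."""
    return {n for n in range(2, dim_size + 1) if dim_size % n == 0}

def selected_world_size(gpu_count: int) -> int:
    """Rule as described in the issue: all GPUs if the count is even, else 4."""
    return gpu_count if gpu_count % 2 == 0 else 4

FIRST_TEST_DIM = 16   # assumption: inferred from "2, 4, 8, 16"
SECOND_TEST_DIM = 12  # assumption: inferred from "2, 3, 4, 6, 12"

first = even_world_sizes(FIRST_TEST_DIM)
second = even_world_sizes(SECOND_TEST_DIM)
both = first & second

print(sorted(first))   # [2, 4, 8, 16]
print(sorted(second))  # [2, 3, 4, 6, 12]
print(sorted(both))    # [2, 4]

# With 6 GPUs the described rule picks a world size of 6, which is not in `both`.
print(selected_world_size(6) in both)  # False
```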

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

tianyu-l commented 12 hours ago

cc: @wz337 on uneven sharding