Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

jittor mpi synchronization is not quite consistent with torch #287

Open yjqiang opened 2 years ago

yjqiang commented 2 years ago
import jittor as jt

NUM_BATCHES = 10  # if there are multiple GPUs, ideally this divides evenly across them
BATCH_SIZE = 3
NUM_CLASSES = 3

def function_jittor():
    jt.misc.set_global_seed(seed=986, different_seed_for_mpi=False)  # sync the seed across devices so every rank generates the same data; otherwise the data differ and the check is bound to fail
    jittor_mpi_core = jt.mpi

    x = jt.randint(low=0, high=NUM_CLASSES, shape=(NUM_BATCHES, BATCH_SIZE))

    local_rank = jittor_mpi_core.local_rank()
    world_size = jittor_mpi_core.world_size()

    num_batches = x.shape[0]
    my_result = jt.array(0)
    for i in range(local_rank, num_batches, world_size):
        my_result += x[i].sum()
        print(f'local_rank={local_rank} is running i={i}: my_result={my_result}')
        result = my_result.mpi_all_reduce('add')

        if local_rank == 0:
            expected_result = x[: i+world_size].sum()
            print(f'local_rank={local_rank} is checking i={i}:', result, expected_result)

if __name__ == '__main__':
    function_jittor()
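
For reference, here is a minimal single-process sketch of the invariant the check above relies on, written in plain Python/numpy with no jittor or MPI; the names below are illustrative, not part of the original script. After every rank finishes loop step k, the per-rank partial sums added together should equal the prefix sum x[: (k + 1) * world_size].sum(), assuming NUM_BATCHES divides evenly across the ranks.

import numpy as np

NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, WORLD_SIZE = 10, 3, 3, 2

rng = np.random.default_rng(986)
x = rng.integers(0, NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))

# Each simulated "rank" strides over the batches exactly like the loop above.
partials = [0] * WORLD_SIZE
for step in range(NUM_BATCHES // WORLD_SIZE):
    for rank in range(WORLD_SIZE):
        i = step * WORLD_SIZE + rank
        partials[rank] += x[i].sum()
    # The value mpi_all_reduce('add') is expected to produce after this step:
    reduced = sum(partials)
    expected = x[: (step + 1) * WORLD_SIZE].sum()
    assert reduced == expected
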
yjqiang commented 2 years ago
import os
import subprocess
import sys
import __main__
from typing import List

import torch.distributed as dist
import torch.utils.data
import torch

NUM_BATCHES = 10  # if there are multiple GPUs, ideally this divides evenly across them
BATCH_SIZE = 3
NUM_CLASSES = 3

WORLD_SIZE = 2
torch.set_printoptions(linewidth=200)

class Environment:
    """
    Each node has num_processes processes; there are self.world_size processes in total.
    """
    def __init__(self):
        self._global_rank: int = 0
        self._world_size: int = 1

    @property
    def creates_processes_externally(self) -> bool:
        """Returns whether the cluster creates the processes or not.
        If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the
        process launcher/job scheduler and Lightning will not launch new processes.
        """
        return "LOCAL_RANK" in os.environ

    def world_size(self) -> int:
        """The number of processes across all devices and nodes."""
        return self._world_size

    def set_world_size(self, size: int) -> None:
        self._world_size = size

    def local_rank(self) -> int:
        """The rank (index) of the currently running process inside of the current node."""
        return int(os.environ.get("LOCAL_RANK", 0))

    def global_rank(self) -> int:
        """The rank (index) of the currently running process across all nodes and devices."""
        return self._global_rank

    def set_global_rank(self, rank: int) -> None:
        self._global_rank = rank

class DDP:
    def __init__(self, parallel_devices: List[torch.device], environment: Environment):
        """
        Each node has num_processes processes.
        :param parallel_devices:
        :param environment:
        """
        self.interactive_ddp_procs = []
        self.parallel_devices = parallel_devices
        self.environment = environment
        self.num_processes = len(self.parallel_devices)
        self.num_nodes = 1
        self.node_rank = 0

    def setup_environment(self) -> None:
        if not self.environment.creates_processes_externally:
            self.call_children_scripts()

        self.setup_distributed()

    def call_children_scripts(self) -> None:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'

        # Figure out how this program was launched (then launch the remaining WORLD_SIZE-1 identical processes the same way)
        if __main__.__spec__ is None:  # pragma: no-cover
            # Script called as `python a/b/c.py`
            # when user is using hydra find the absolute path
            path_lib = os.path.abspath

            # pull out the commands used to run the script and resolve the abs file path
            command = sys.argv
            try:
                full_path = path_lib(command[0])
            except Exception:
                full_path = os.path.abspath(command[0])

            command[0] = full_path
            # use the same python interpreter and actually running
            command = [sys.executable] + command
        else:  # Script called as `python -m a.b.c`
            command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

        for rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{rank}"
            proc = subprocess.Popen(command, env=env_copy)
            self.interactive_ddp_procs.append(proc)
            # delay = np.random.uniform(1, 5, 1)[0]
            # sleep(delay)

    def setup_distributed(self) -> None:
        self.environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.environment.set_world_size(self.num_nodes * self.num_processes)

        global_rank = self.environment.global_rank()
        world_size = self.environment.world_size()
        dist.init_process_group(backend="nccl", world_size=world_size, rank=global_rank)

    @property
    def local_rank(self) -> int:
        return self.environment.local_rank()

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]

def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    tensors = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors, tensor)
    return torch.stack(tensors, dim=0)

def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = "0, 1"

    gpu_ids = [0, 1]
    devices = [torch.device("cuda", i) for i in gpu_ids]
    environment = Environment()

    plugin = DDP(parallel_devices=devices, environment=environment)

    plugin.setup_environment()

    # bind this process to its own GPU for distributed training
    torch.cuda.set_device(plugin.environment.local_rank())

    seed = 998
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    x = torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE))
    x = x.to(plugin.root_device)

    local_rank = environment.local_rank()
    world_size = environment.world_size()

    num_batches = len(x)
    my_result = torch.Tensor([0]).sum().to(device=plugin.root_device)
    for i in range(local_rank, num_batches, world_size):
        my_result += x[i].sum()
        print(f'local_rank={local_rank} is running i={i}: my_result={my_result}')

        gathered_result = [torch.zeros_like(my_result) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gathered_result, my_result)
        result = torch.stack(gathered_result).sum()

        if local_rank == 0:
            expected_result = x[: i + world_size].sum()
            print(f'local_rank={local_rank} is checking i={i}:', result, expected_result)

if __name__ == "__main__":
    main()
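
For comparison: the all_gather followed by stack().sum() in main() computes the same value as a single sum all_reduce. A minimal sketch of that variant, assuming the process group set up in setup_distributed() is already initialized and my_result is the per-rank partial sum from the loop above:

result = my_result.clone()
torch.distributed.all_reduce(result, op=torch.distributed.ReduceOp.SUM)  # sums the partials across all ranks, in place on the clone
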
yjqiang commented 2 years ago

The jittor version doesn't run.

cjld commented 2 years ago

Hi, what does the error message look like on your side?