yjqiang opened 2 years ago
import os
import subprocess
import sys
import __main__
from typing import List
import torch.distributed as dist
import torch.utils.data
import torch
NUM_BATCHES = 10  # with multiple GPUs, ideally a count that splits evenly across them
BATCH_SIZE = 3
NUM_CLASSES = 3
WORLD_SIZE = 2
torch.set_printoptions(linewidth=200)
class Environment:
    """
    Each node runs num_processes processes; there are self.world_size processes in total.
    """

    def __init__(self):
        self._global_rank: int = 0
        self._world_size: int = 1

    @property
    def creates_processes_externally(self) -> bool:
        """Returns whether the cluster creates the processes or not.

        If at least :code:`LOCAL_RANK` is available as an environment variable, Lightning assumes the user acts
        as the process launcher/job scheduler and Lightning will not launch new processes.
        """
        return "LOCAL_RANK" in os.environ

    def world_size(self) -> int:
        """The number of processes across all devices and nodes."""
        return self._world_size

    def set_world_size(self, size: int) -> None:
        self._world_size = size

    def local_rank(self) -> int:
        """The rank (index) of the currently running process inside the current node."""
        return int(os.environ.get("LOCAL_RANK", 0))

    def global_rank(self) -> int:
        """The rank (index) of the currently running process across all nodes and devices."""
        return self._global_rank

    def set_global_rank(self, rank: int) -> None:
        self._global_rank = rank
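The parent/child distinction above hinges entirely on the LOCAL_RANK variable: the hand-launched parent does not have it, while every spawned child does (it is set in call_children_scripts below). A minimal illustration of what each process would see:

# In the hand-launched parent process (no LOCAL_RANK in its environment):
env = Environment()
print(env.creates_processes_externally)  # False -> this process must spawn the children
print(env.local_rank())                  # 0, the default

# In a child spawned with env_copy["LOCAL_RANK"] = "1", the same two lines
# print True and 1, so the child skips call_children_scripts().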
class DDP:
    def __init__(self, parallel_devices: List[torch.device], environment: Environment):
        """
        Each node runs num_processes processes.

        :param parallel_devices: one torch.device per process on this node
        :param environment: tracks local/global rank and world size
        """
        self.interactive_ddp_procs = []
        self.parallel_devices = parallel_devices
        self.environment = environment
        self.num_processes = len(self.parallel_devices)
        self.num_nodes = 1
        self.node_rank = 0

    def setup_environment(self) -> None:
        if not self.environment.creates_processes_externally:
            self.call_children_scripts()
        self.setup_distributed()

    def call_children_scripts(self) -> None:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'

        # Figure out how this program was launched, then launch the remaining
        # WORLD_SIZE - 1 identical processes the same way.
        if __main__.__spec__ is None:  # pragma: no-cover
            # Script called as `python a/b/c.py`
            # resolve the absolute path (needed e.g. when the user runs under hydra)
            path_lib = os.path.abspath

            # pull out the command used to run the script and resolve the absolute file path
            command = sys.argv
            try:
                full_path = path_lib(command[0])
            except Exception:
                full_path = os.path.abspath(command[0])

            command[0] = full_path
            # use the same Python interpreter that is currently running
            command = [sys.executable] + command
        else:  # Script called as `python -m a.b.c`
            command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

        for rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{rank}"
            proc = subprocess.Popen(command, env=env_copy)
            self.interactive_ddp_procs.append(proc)
            # delay = np.random.uniform(1, 5, 1)[0]
            # sleep(delay)

    def setup_distributed(self) -> None:
        self.environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        self.environment.set_world_size(self.num_nodes * self.num_processes)

        global_rank = self.environment.global_rank()
        world_size = self.environment.world_size()
        dist.init_process_group(backend="nccl", world_size=world_size, rank=global_rank)

    @property
    def local_rank(self) -> int:
        return self.environment.local_rank()

    @property
    def root_device(self) -> torch.device:
        return self.parallel_devices[self.local_rank]
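For comparison only: the relaunch mechanism above re-executes the script as one subprocess per extra rank, which mirrors Lightning's subprocess launcher. The same two-process setup is often written with torch.multiprocessing.spawn instead, which calls a worker function per rank rather than re-running the file. A sketch (the _worker name and body are illustrative, not part of the repro):

import torch.multiprocessing as mp

def _worker(local_rank: int, world_size: int) -> None:
    # Each spawned worker joins the process group under its own rank.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

# mp.spawn passes the rank as the first argument to _worker:
# mp.spawn(_worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE)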
def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """Gather `tensor` from every process and stack the copies along a new dim 0."""
    tensors = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors, tensor)
    return torch.stack(tensors, dim=0)
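This helper is not actually called in main() below (the gather is inlined there), but its contract matters for the check: every rank receives the same stacked [world_size, *tensor.shape] result. A hypothetical two-rank illustration:

# With world_size == 2, if rank 0 holds tensor([1., 2.]) and rank 1 holds tensor([3., 4.]),
# then on *both* ranks:
#   all_gather(local_tensor) -> tensor([[1., 2.],
#                                       [3., 4.]])
# i.e. row r is rank r's contribution, so .sum() of the result agrees everywhere.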
def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
    gpu_ids = [0, 1]
    devices = [torch.device("cuda", i) for i in gpu_ids]
    environment = Environment()
    plugin = DDP(parallel_devices=devices, environment=environment)
    plugin.setup_environment()

    # The process group is up; bind this process to its GPU for distributed training.
    torch.cuda.set_device(plugin.environment.local_rank())

    seed = 998
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Every rank seeds identically, so every rank generates the same x.
    x = torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE))
    x = x.to(plugin.root_device)

    local_rank = environment.local_rank()
    world_size = environment.world_size()
    num_batches = len(x)
    my_result = torch.Tensor([0]).sum().to(device=plugin.root_device)  # scalar accumulator
    # Round-robin sharding: rank r consumes batches r, r + world_size, r + 2 * world_size, ...
    for i in range(local_rank, num_batches, world_size):
        my_result += x[i].sum()
        print(f'local_rank={local_rank} is running i={i}: my_result={my_result}')
        # Collective call: every rank reaches it exactly once per round.
        gathered_result = [torch.zeros_like(my_result) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gathered_result, my_result)
        result = torch.stack(gathered_result).sum()
        if local_rank == 0:
            # After this round, the world_size ranks have together consumed x[: i + world_size].
            expected_result = x[: i + world_size].sum()
            print(f'local_rank={local_rank} is checking i={i}:', result, expected_result)


if __name__ == "__main__":
    main()
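For anyone trying to reproduce this: run the file once with plain python; the first process has no LOCAL_RANK and spawns the second one itself, so no torchrun or other launch utility is needed. If only the control flow is of interest and two GPUs are not available, the same launch-then-gather flow can be sketched CPU-only with the gloo backend (saved as its own file; the file name and cpu_main are hypothetical):

# repro_cpu.py -- CPU-only sketch of the same self-launching flow (gloo instead of nccl)
import os
import subprocess
import sys

import torch
import torch.distributed as dist

def cpu_main() -> None:
    if "LOCAL_RANK" not in os.environ:
        # Parent: relaunch this same script once as rank 1, mirroring call_children_scripts.
        subprocess.Popen([sys.executable, sys.argv[0]], env={**os.environ, "LOCAL_RANK": "1"})
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    rank = int(os.environ.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="gloo", world_size=2, rank=rank)
    t = torch.tensor([float(rank)])
    gathered = [torch.zeros_like(t) for _ in range(2)]
    dist.all_gather(gathered, t)
    print(f"rank={rank} gathered {gathered}")  # both ranks: [tensor([0.]), tensor([1.])]

if __name__ == "__main__":
    cpu_main()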
The jittor version won't run.
Hello, what error message are you getting on your end?