It looks like torch.distributed is mainly split into two big components. For DDP, a quick look at the design is enough; no need to spend much time on it. The focus should be on RPC and torchelastic.
https://pytorch.org/docs/stable/notes/ddp.html
Nothing particularly special. Gradients are grouped into buckets (default size 25MB). Once every gradient in a bucket has been computed, an asynchronous allreduce is fired for that bucket. After all buckets have been fired, it blocks and waits for all allreduce operations to finish. Essentially the same design as Horovod.
Implementation-wise, there are two main pieces: the Python-side nn.parallel.DistributedDataParallel module, and the C++ core it calls into through many C++ interfaces. They even published a VLDB '20 paper about it: http://www.vldb.org/pvldb/vol13/p3005-li.pdf
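For intuition, here is a rough, simplified sketch of the bucketed async-allreduce pattern described above. This is my own illustration, not DDP's actual code: it assumes a process group has already been initialized and runs after backward() instead of inside autograd hooks.
# Simplified illustration of bucketed asynchronous allreduce (not DDP's real code).
# Assumes torch.distributed.init_process_group() has already been called.
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_grads_bucketed(params, bucket_cap_bytes=25 * 1024 * 1024):
    grads = [p.grad for p in params if p.grad is not None]

    # Group gradients into ~25MB buckets.
    buckets, current, current_bytes = [], [], 0
    for g in grads:
        current.append(g)
        current_bytes += g.numel() * g.element_size()
        if current_bytes >= bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
    if current:
        buckets.append(current)

    # Fire one async allreduce per bucket (real DDP fires a bucket as soon as
    # all of its gradients are ready, via autograd hooks; here we just loop).
    pending = []
    for bucket in buckets:
        flat = _flatten_dense_tensors(bucket)
        handle = dist.all_reduce(flat, op=dist.ReduceOp.SUM, async_op=True)
        pending.append((handle, flat, bucket))

    # Block until every allreduce finishes, then average and copy back.
    for handle, flat, bucket in pending:
        handle.wait()
        flat.div_(dist.get_world_size())
        for g, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
            g.copy_(synced)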
https://pytorch.org/docs/master/rpc.html
Main components:
https://pytorch.org/docs/master/rpc/rref.html
There are two kinds of RRef:
- OwnerRRef: only one instance exists, and it holds the real data. It has a global id (RRefId).
- UserRRef: there can be many, and they do not hold the data. A UserRRef gets created in three situations (see the RRef note for the list).
The design has to guarantee two points (G1 and G2); the gist is that an OwnerRRef must not be deleted while some UserRRef is still alive, i.e., not until every UserRRef is deleted. The docs explain G2 very well:
In cases 2 and 3, it is possible that the owner has only partial or no knowledge at all about the RRef fork graph. For example, an RRef could be constructed on a user, and before the owner receives any RPC call, the creator user might have already shared the RRef with other users, and those users could further share the RRef. One invariant is that the fork graph of any RRef is always a tree, because forking an RRef always creates a new UserRRef instance on the callee (except if the callee is the owner), and hence every RRef has a single parent.
Indeed, this owner/borrower pattern yields a tree, which looks a lot like a fork tree! That is why creating an RRef is described here as "forking an RRef". Note that the root of the tree is not necessarily the OwnerRRef; it can just as well be a UserRRef, depending on how the program makes its calls.
As for why G2 is needed, the discussion in the docs is quite enlightening; I won't repeat it here.
Let's look at an example.
import torch
import torch.distributed.rpc as rpc
# on worker A and worker C
def func(rref):
pass
# on worker A
rref = rpc.remote('B', torch.add, args=(torch.ones(2), 1))
# say the rref has RRefId 100 and ForkId 1
rpc.rpc_async('C', func, args=(rref, ))
Here, 100 and 1 are the RRefId and the ForkId, respectively.
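To make it concrete that the UserRRef on worker C does not hold the data, func could for example fetch it explicitly. This is my own sketch, not from the docs:
# on worker C: the rref argument is a UserRRef; the tensor itself still lives
# on the owner (worker B) and must be fetched explicitly.
def func(rref):
    print(rref.owner())       # WorkerInfo of the owner, i.e. worker B
    value = rref.to_here()    # pulls the tensor from B over RPC
    return value + 1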
https://pytorch.org/docs/master/rpc/distributed_autograd.html
Take a look at the code below. It really is very flexible.
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)
I'll stop here for now. My impression is that RPC is mainly used for the control plane.
Now for torchelastic. State recovery relies on checkpoints, so the script's logic should ideally load the checkpoint right at the start:
def main():
load_checkpoint(checkpoint_path)
initialize()
train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_checkpoint(checkpoint_path)
Here is an example of a load_checkpoint function. It is nicely written:
https://github.com/pytorch/elastic/blob/master/examples/imagenet/main.py#L315-L391
def load_checkpoint(
checkpoint_file: str,
device_id: int,
arch: str,
model: DistributedDataParallel,
optimizer, # SGD
) -> State:
"""
Loads a local checkpoint (if any). Otherwise, checks to see if any of
the neighbors have a non-zero state. If so, restore the state
from the rank that has the most up-to-date checkpoint.
.. note:: when your job has access to a globally visible persistent storage
(e.g. nfs mount, S3) you can simply have all workers load
from the most recent checkpoint from such storage. Since this
example is expected to run on vanilla hosts (with no shared
storage) the checkpoints are written to local disk, hence
we have the extra logic to broadcast the checkpoint from a
surviving node.
"""
state = State(arch, model, optimizer)
if os.path.isfile(checkpoint_file):
print(f"=> loading checkpoint file: {checkpoint_file}")
state.load(checkpoint_file, device_id)
print(f"=> loaded checkpoint file: {checkpoint_file}")
# logic below is unnecessary when the checkpoint is visible on all nodes!
# create a temporary cpu pg to broadcast most up-to-date checkpoint
with tmp_process_group(backend="gloo") as pg:
rank = dist.get_rank(group=pg)
# get rank that has the largest state.epoch
epochs = torch.zeros(dist.get_world_size(), dtype=torch.int32)
epochs[rank] = state.epoch
dist.all_reduce(epochs, op=dist.ReduceOp.SUM, group=pg)
t_max_epoch, t_max_rank = torch.max(epochs, dim=0)
max_epoch = t_max_epoch.item()
max_rank = t_max_rank.item()
# max_epoch == -1 means no one has checkpointed return base state
if max_epoch == -1:
print(f"=> no workers have checkpoints, starting from epoch 0")
return state
# broadcast the state from max_rank (which has the most up-to-date state)
# pickle the snapshot, convert it into a byte-blob tensor
# then broadcast it, unpickle it and apply the snapshot
print(f"=> using checkpoint from rank: {max_rank}, max_epoch: {max_epoch}")
with io.BytesIO() as f:
torch.save(state.capture_snapshot(), f)
raw_blob = numpy.frombuffer(f.getvalue(), dtype=numpy.uint8)
blob_len = torch.tensor(len(raw_blob))
dist.broadcast(blob_len, src=max_rank, group=pg)
print(f"=> checkpoint broadcast size is: {blob_len}")
if rank != max_rank:
blob = torch.zeros(blob_len.item(), dtype=torch.uint8)
else:
blob = torch.as_tensor(raw_blob, dtype=torch.uint8)
dist.broadcast(blob, src=max_rank, group=pg)
print(f"=> done broadcasting checkpoint")
if rank != max_rank:
with io.BytesIO(blob.numpy()) as f:
snapshot = torch.load(f)
state.apply_snapshot(snapshot, device_id)
# wait till everyone has loaded the checkpoint
dist.barrier(group=pg)
print(f"=> done restoring from previous checkpoint")
return state
Starting from PyTorch v1.9, elastic training is supported natively; the torchelastic project has been merged upstream. On the API side, torch.distributed.run is now the recommended entry point. Functionally, torch.distributed.run is a superset of the old torch.distributed.launch, adding fault tolerance and elastic scaling.
The recommended way to launch is now:
>>> python -m torch.distributed.run
--nnodes=$NUM_NODES
--nproc_per_node=$NUM_TRAINERS
--rdzv_id=$JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$HOST_NODE_ADDR
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
There are three new arguments; the docs describe them as follows:
- --rdzv_id: A unique job id (shared by all nodes participating in the job).
- --rdzv_backend: An implementation of torch.distributed.elastic.rendezvous.RendezvousHandler. (--rdzv_backend defaults to static, which does not support fault tolerance or elastic scaling.)
- --rdzv_endpoint: The endpoint where the rendezvous backend is running; usually in form host:port. (--rdzv_endpoint effectively replaces the old --master_addr and --master_port.)
The old launch arguments, such as --node_rank, --master_addr and --master_port, are still supported for backward compatibility.
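As a side note (a minimal sketch of my own, not from the docs): a script launched via torch.distributed.run can bootstrap itself entirely from the environment variables the agent sets, rather than parsing a --local_rank argument the way torch.distributed.launch scripts used to.
# Minimal sketch of a worker script that relies only on the environment
# variables set by the elastic agent (LOCAL_RANK, RANK, WORLD_SIZE,
# MASTER_ADDR, MASTER_PORT).
import os
import torch
import torch.distributed as dist

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    # "env://" picks up MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE automatically.
    dist.init_process_group(backend="nccl", init_method="env://")
    print(f"worker {rank}/{world_size} on GPU {local_rank} is up")

if __name__ == "__main__":
    main()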
torch/distributed/run.py
The entry function of torch.distributed.run:
def run(args):
if args.standalone:
args.rdzv_backend = "c10d"
args.rdzv_endpoint = "localhost:29400"
args.rdzv_id = str(uuid.uuid4())
log.info(
f"\n**************************************\n"
f"Rendezvous info:\n"
f"--rdzv_backend={args.rdzv_backend} "
f"--rdzv_endpoint={args.rdzv_endpoint} "
f"--rdzv_id={args.rdzv_id}\n"
f"**************************************\n"
)
config, cmd, cmd_args = config_from_args(args)
elastic_launch(
config=config,
entrypoint=cmd,
)(*cmd_args)
As you can see, elastic_launch has become the default launch path.
torch/distributed/launcher/api.py
elastic_launch calls the launch_agent function, which creates and starts an agent:
agent = LocalElasticAgent(
spec=spec, start_method=config.start_method, log_dir=config.log_dir
)
result = agent.run()
torch/distributed/elastic/agent/server/api.py
The base class for agents is ElasticAgent, whose main exposed API is run, i.e., starting the agent. SimpleElasticAgent inherits from ElasticAgent, implements part of the functionality, and leaves four abstractmethods for subclasses to implement. The agent actually used is the LocalElasticAgent class, which inherits from SimpleElasticAgent.
The SimpleElasticAgent class
First, there is the notion of a WorkerGroup, which is roughly the set of workers on one machine, with one worker per GPU.
The logic of the run function mostly lives in _invoke_run, whose flow is as follows:
- _initialize_workers: first performs rendezvous to obtain the node rank and node world size, then starts the workers. _start_workers is an abstractmethod.
def _initialize_workers(self, worker_group: WorkerGroup) -> None:
r"""
Starts a fresh set of workers for the worker_group.
Essentially a rendezvous followed by a start_workers.
The caller should first call ``_stop_workers()`` to stop running workers
prior to calling this method.
Optimistically sets the state of the worker group that
just started as ``HEALTHY`` and delegates the actual monitoring
of state to ``_monitor_workers()`` method
"""
role = worker_group.spec.role
log.info(f"[{role}] Rendezvous'ing worker group")
self._rendezvous(worker_group)
log.info(f"[{role}] Starting worker group")
worker_ids = self._start_workers(worker_group)
for local_rank, w_id in worker_ids.items():
worker = worker_group.workers[local_rank]
worker.id = w_id
worker_group.state = WorkerState.HEALTHY
How the rendezvous itself works is covered in the next section; the interface used here is the handler's next_rendezvous function.
- The agent then enters a monitoring loop: every monitor_interval it calls _monitor_workers to get a run_result. If run_result is SUCCEEDED, training finished successfully; if it is UNHEALTHY or FAILED, the workers are restarted; if it is HEALTHY and there are waiting nodes, the workers also need to be restarted.
while True:
assert self._worker_group.state != WorkerState.INIT
time.sleep(monitor_interval)
run_result = self._monitor_workers(self._worker_group)
state = run_result.state
self._worker_group.state = state
put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
put_metric(f"workers.{role}.{state.name.lower()}", 1)
if state == WorkerState.SUCCEEDED:
log.info(
f"[{role}] worker group successfully finished."
f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
)
self._exit_barrier()
return run_result
elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
if self._remaining_restarts > 0:
log.info(
f"[{role}] Worker group {state.name}. "
f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
f" will restart worker group"
)
self._remaining_restarts -= 1
self._restart_workers(self._worker_group)
else:
self._stop_workers(self._worker_group)
self._worker_group.state = WorkerState.FAILED
self._exit_barrier()
return run_result
elif state == WorkerState.HEALTHY:
# membership changes do not count as retries
num_nodes_waiting = rdzv_handler.num_nodes_waiting()
group_rank = self._worker_group.group_rank
if num_nodes_waiting > 0:
log.info(
f"[{role}] Detected {num_nodes_waiting} "
f"new nodes from group_rank={group_rank}; "
f"will restart worker group"
)
self._restart_workers(self._worker_group)
else:
raise Exception(f"[{role}] Worker group in {state.name} state")
The _monitor_workers function is also an abstractmethod.
The _restart_workers function: first stop the workers, then re-initialize them.
def _restart_workers(self, worker_group: WorkerGroup) -> None:
"""
Restarts (stops, rendezvous, starts) all local workers in the group.
"""
role = worker_group.spec.role
log.info(f"[{role}] Stopping worker group")
self._stop_workers(worker_group)
worker_group.state = WorkerState.STOPPED
self._initialize_workers(worker_group)
_stop_workers is also an abstractmethod.
torch/distributed/elastic/agent/server/local_elastic_agent.py
Now let's look at the concrete implementations of the abstractmethods mentioned above.
The _start_workers function starts n workers (processes), mainly by calling start_processes:
self._pcontext = start_processes(
name=spec.role,
entrypoint=spec.entrypoint,
args=args,
envs=envs,
log_dir=attempt_log_dir,
start_method=self._start_method,
redirects=spec.redirects,
tee=spec.tee,
)
Note:
- entrypoint and args are the user's command and its arguments; entrypoint can be a function or a str.
- envs contains LOCAL_RANK, RANK, GROUP_RANK, GROUP_SIZE, and so on. Each worker can read these values from environment variables.
- The return value of start_processes is a PContext (Process Context), which provides APIs for operating on the processes, such as start, close, and wait.
PContext has two subclasses: MultiprocessContext and SubprocessContext. If the entrypoint is a function, the former is used; if it is a str, the latter.
- MultiprocessContext: implemented on top of torch.multiprocessing.
- SubprocessContext: implemented on top of subprocess.Popen.
The _stop_workers function: calls PContext's close operation.
The _monitor_workers function: calls PContext's wait operation to poll the run status of the processes.
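To tie the pieces together, here is a small standalone sketch of using start_processes and the returned PContext directly, assuming the v1.9-era signature shown in the LocalElasticAgent snippet above (a function entrypoint yields a MultiprocessContext; a str entrypoint would yield a SubprocessContext):
# Toy use of start_processes outside the agent, assuming the v1.9-style
# API shown above (name / entrypoint / args / envs / log_dir).
import os
from torch.distributed.elastic.multiprocessing import start_processes

def trainer(msg):
    print(msg)

if __name__ == "__main__":
    log_dir = "/tmp/elastic_example_logs"
    os.makedirs(log_dir, exist_ok=True)
    ctx = start_processes(
        name="trainer",
        entrypoint=trainer,                     # a function -> MultiprocessContext
        args={0: ("hello from local_rank 0",),  # one entry per local worker
              1: ("hello from local_rank 1",)},
        envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
        log_dir=log_dir,
    )
    try:
        result = ctx.wait()                     # block until all workers finish
        print(result.return_values)
    finally:
        ctx.close()                             # what _stop_workers ultimately does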
The docs explain the term rendezvous as follows:
In the context of Torch Distributed Elastic we use the term rendezvous to refer to a particular functionality that combines a distributed synchronization primitive with peer discovery.
It is used by Torch Distributed Elastic to gather participants of a training job (i.e. nodes) such that they all agree on the same list of participants and everyone’s roles, as well as make a consistent collective decision on when training can begin/resume.
It looks like a tool for implementing consensus: the participants reach agreement on membership, and each gets a rank and the world_size. Whenever a failure happens, or a scale up / scale down is needed, a re-rendezvous is performed.
torch/distributed/elastic/rendezvous/dynamic_rendezvous.py
The core rendezvous logic is in the next_rendezvous function of the DynamicRendezvousHandler class. The call can end in one of three ways.
The basic logic is to first have this node exit (_RendezvousExitOp) and then re-join (_RendezvousJoinOp), as the code below shows:
try:
self._stop_heartbeats()
# Delay the execution for a small random amount of time if this is our
# first run. This will slightly skew the rendezvous attempts across the
# nodes and reduce the load on the backend.
if self._state_holder.state.round == 0:
_delay(seconds=(0, 0.3))
exit_op = _RendezvousExitOp()
join_op = _RendezvousJoinOp()
deadline = self._get_deadline(self._settings.timeout.join)
self._op_executor.run(exit_op, deadline)
self._op_executor.run(join_op, deadline)
self._start_heartbeats()
rank, world_size = self._get_world()
store = self._get_store()
except Exception as e:
self._record(
message=f"{type(e).__name__}: {str(e)}",
node_state=NodeState.FAILED,
)
raise
In the implementation, each node maintains a local rendezvous state, which includes the completion status, the participants, the wait list, and so on:
class _RendezvousState:
"""Holds the state of a rendezvous.
Attributes:
round:
The current round of the rendezvous.
complete:
A boolean value indicating whether the current round of the
rendezvous is complete.
deadline:
The time at which the current round of the rendezvous will be
considered complete if it is still waiting for nodes to join.
closed:
A boolean value indicating whether the rendezvous is closed.
participants:
A dictionary of the participants and their corresponding ranks.
wait_list:
A set of nodes that are waiting to participate in the next round of
the rendezvous.
last_heartbeats:
A dictionary containing each node's last heartbeat time.
"""
It is worth noting that before executing any op, the state is synced to reach consensus. The sync is implemented on top of c10d's TCP key-value store, i.e., the rendezvous backend. Executing an op may update the state: if there was a local write, a compare-and-set is done first (it returns the new state and whether the state was actually updated); if not, a plain get is enough.
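As a rough illustration of that get / compare-and-set pattern, here is a toy sketch of my own done by hand against a plain c10d TCPStore (the real c10d rendezvous backend stores a pickled, base64-encoded _RendezvousState under a single key):
# Toy illustration of the "sync before each op" idea described above
# (not the actual backend code).
import base64
import pickle
from datetime import timedelta

from torch.distributed import TCPStore

store = TCPStore("localhost", 29400, 1, True, timeout=timedelta(seconds=30))
KEY = "rdzv_state"

def encode(state):
    return base64.b64encode(pickle.dumps(state)).decode()

def decode(blob):
    return pickle.loads(base64.b64decode(blob))

store.set(KEY, encode({"participants": {}, "wait_list": set()}))

def sync(new_state=None, last_seen_blob=""):
    if new_state is None:
        return decode(store.get(KEY))          # no local write: a plain get is enough
    # local write: compare-and-set against the blob we last saw; the store
    # returns whatever value ends up stored, so we can see whether our write won.
    stored = store.compare_set(KEY, last_seen_blob, encode(new_state))
    return decode(stored)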
The ops themselves work as follows:
- exit: if the node is not among the participants, do nothing; if it is, remove it from the participants list.
class _RendezvousExitOp:
"""Represents a rendezvous exit operation."""
def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
if ctx.node in ctx.state.participants:
if time.monotonic() > deadline:
return _Action.ERROR_TIMEOUT
return _Action.REMOVE_FROM_PARTICIPANTS
return _Action.FINISH
- join: tries to add this node to the participants; it may end up on the wait list instead, or may not even make it onto the wait list and simply has to keep waiting.
if state.complete:
# If we are here, it means we are not part of the rendezvous. In
# case the rendezvous has capacity for additional participants add
# ourself to the wait list for the next round.
if len(state.participants) < ctx.settings.max_nodes:
if ctx.node not in state.wait_list:
return _Action.ADD_TO_WAIT_LIST
elif is_participant:
# If the rendezvous has enough number of participants including us,
# check whether we have passed the rendezvous deadline. If yes,
# complete it.
if len(state.participants) >= ctx.settings.min_nodes:
if cast(datetime, state.deadline) < datetime.utcnow():
return _Action.MARK_RENDEZVOUS_COMPLETE
else:
# The rendezvous is not complete yet and we are not part of it. Try
# to join.
return _Action.ADD_TO_PARTICIPANTS
- keep-alive: besides join and exit, there is a keep-alive operation used to detect whether a node has failed. It runs periodically (every 5s) and updates last_heartbeats in the local state. At the next state sync, the last heartbeats are written to the kv store, so the other healthy nodes learn about them.
def _keep_alive(self) -> None:
msg = (
f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
f"'{self._settings.run_id}'. Pending sync."
)
self._record(message=msg)
log.debug(msg)
self._state.last_heartbeats[self._node] = datetime.utcnow()
state.last_heartbeats is used to filter out dead nodes: if a node's last heartbeat is older than a certain threshold, it is marked as dead and removed from the state's participants or wait list.
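A rough sketch of that filtering step, with assumed names (the real cleanup happens when the state holder syncs):
# Hypothetical helper showing how stale heartbeats translate into dead nodes;
# the actual logic lives in dynamic_rendezvous.py's state-sync path.
from datetime import datetime, timedelta

def drop_dead_nodes(state, keep_alive_interval=timedelta(seconds=5), max_missed=3):
    expire_before = datetime.utcnow() - keep_alive_interval * max_missed
    dead = [node for node, ts in state.last_heartbeats.items() if ts < expire_before]
    for node in dead:
        del state.last_heartbeats[node]
        state.participants.pop(node, None)  # drop from participants...
        state.wait_list.discard(node)       # ...and from the wait list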
As you can see, building on a kv store is a clever design, quite similar to a master-client setup: the master holds some metadata, clients update it and everyone can read it, and the master periodically drops clients that stop updating their metadata. It really does simplify the design quite a bit.
It seems I never looked closely at the static mode before. I'll take a proper look at it later.
https://pytorch.org/tutorials/beginner/dist_overview.html