jasperzhong / cs-notes

CS knowledge system

PyTorch Distributed #22

Closed jasperzhong closed 2 years ago

jasperzhong commented 2 years ago

https://pytorch.org/tutorials/beginner/dist_overview.html

jasperzhong commented 2 years ago

It looks like torch.distributed mainly consists of two components:

For DDP it's enough to skim the design; no need to spend much time there. Focus on RPC and torchelastic.

jasperzhong commented 2 years ago

https://pytorch.org/docs/stable/notes/ddp.html

DDP

Nothing particularly surprising. Gradients are grouped into buckets (25MB by default). Once every gradient in a bucket has been computed, an asynchronous allreduce is fired for that bucket. After all buckets have been fired, DDP blocks until all allreduce operations finish. Essentially the same design as Horovod.
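
As a rough sketch of the bucket-then-allreduce idea (not DDP's actual implementation, which registers autograd hooks so communication overlaps with the backward pass), a hand-rolled version could look like this, assuming the process group is already initialized:

import torch
import torch.distributed as dist

def allreduce_in_buckets(params, bucket_cap_mb=25):
    # Group gradients into ~bucket_cap_mb buckets (DDP's default is 25MB).
    buckets, bucket, nbytes = [], [], 0
    for p in params:
        if p.grad is None:
            continue
        bucket.append(p.grad)
        nbytes += p.grad.numel() * p.grad.element_size()
        if nbytes >= bucket_cap_mb * 1024 * 1024:
            buckets.append(bucket)
            bucket, nbytes = [], 0
    if bucket:
        buckets.append(bucket)

    # Fire one async allreduce per bucket, then wait on all of them.
    pending = []
    for grads in buckets:
        flat = torch.cat([g.reshape(-1) for g in grads])
        work = dist.all_reduce(flat, async_op=True)
        pending.append((work, flat, grads))

    for work, flat, grads in pending:
        work.wait()
        flat /= dist.get_world_size()  # average across workers
        offset = 0
        for g in grads:  # copy the reduced values back into each .grad
            g.copy_(flat[offset:offset + g.numel()].view_as(g))
            offset += g.numel()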

Implementation-wise, there are two main parts:

(image)


They even published a VLDB '20 paper on it: http://www.vldb.org/pvldb/vol13/p3005-li.pdf

jasperzhong commented 2 years ago

https://pytorch.org/docs/master/rpc.html

RPC

Main components:

jasperzhong commented 2 years ago

https://pytorch.org/docs/master/rpc/rref.html

RRef

There are two kinds of RRef: OwnerRRef and UserRRef.

A UserRRef is created in three cases:

  1. Receiving a UserRRef from the owner.
  2. Receiving a UserRRef from another user.
  3. Creating a new UserRRef owned by another worker.

The design has to guarantee two points:

The docs explain G2 very well.

In cases 2 and 3, it is possible that the owner has only partial or no knowledge at all about the RRef fork graph. For example, an RRef could be constructed on a user, and before the owner receives any RPC call, the creator user might have already shared the RRef with other users, and those users could further share the RRef. One invariant is that the fork graph of any RRef is always a tree, because forking an RRef always creates a new UserRRef instance on the callee (except if the callee is the owner), and hence every RRef has a single parent.

Indeed, this owner/borrower pattern produces a tree, which looks a lot like a process fork tree; that's why creating an RRef is called "forking an RRef". Note that the root of this tree is not necessarily the OwnerRRef; it can just as well be a UserRRef, depending on the call structure of the program.

For why G2 is needed, see the discussion in the docs; it's quite enlightening, so I won't repeat it here.

Here is an example.

import torch
import torch.distributed.rpc as rpc

# on worker A and worker C
def func(rref):
  pass

# on worker A
rref = rpc.remote('B', torch.add, args=(torch.ones(2), 1))
# say the rref has RRefId 100 and ForkId 1
rpc.rpc_async('C', func, args=(rref, ))

(image)

Here {100, 1} are the RRefId and the ForkId, respectively.

jasperzhong commented 2 years ago

https://pytorch.org/docs/master/rpc/distributed_autograd.html

Distributed Autograd Design

Take a look at this code. It is indeed very flexible.

import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(rank, dst_rank, world_size):
    name = "worker{}".format(rank)
    dst_name = "worker{}".format(dst_rank)

    # Initialize RPC.
    rpc.init_rpc(
        name=name,
        rank=rank,
        world_size=world_size
    )

    # Use a distributed autograd context.
    with dist_autograd.context() as context_id:
        # Forward pass (create references on remote nodes).
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass (run distributed autograd).
        dist_autograd.backward(context_id, [loss.sum()])

        # Build DistributedOptimizer.
        dist_optim = DistributedOptimizer(
            optim.SGD,
            [rref1, rref2],
            lr=0.05,
        )

        # Run the distributed optimizer step.
        dist_optim.step(context_id)

def run_process(rank, world_size):
    dst_rank = (rank + 1) % world_size
    _run_process(rank, dst_rank, world_size)
    rpc.shutdown()

if __name__ == '__main__':
  # Run world_size workers
  world_size = 2
  mp.spawn(run_process, args=(world_size,), nprocs=world_size)
jasperzhong commented 2 years ago

I'll stop here for now. My impression is that RPC is mainly meant for control-plane work.
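
A minimal sketch of what I mean by control-plane usage (the set_lr/broadcast_lr helpers and the worker names are made up for illustration; it assumes rpc.init_rpc has already been called on every worker):

import torch.distributed.rpc as rpc

CURRENT_LR = 0.1

def set_lr(new_lr):
    # Runs on the remote worker; mutates worker-local training state.
    global CURRENT_LR
    CURRENT_LR = new_lr
    return CURRENT_LR

def broadcast_lr(worker_names, new_lr):
    # The coordinator pushes a control decision to every worker via RPC.
    futs = [rpc.rpc_async(name, set_lr, args=(new_lr,)) for name in worker_names]
    return [fut.wait() for fut in futs]

# e.g. on the coordinator: broadcast_lr(["worker1", "worker2"], 0.01)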

jasperzhong commented 2 years ago

https://pytorch.org/docs/stable/distributed.html

jasperzhong commented 2 years ago

PyTorch Elastic

Recovery relies on checkpoints. It's best if the script loads a checkpoint right at the start.

def main():
  load_checkpoint(checkpoint_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_checkpoint(checkpoint_path)

Here is an example of a load_checkpoint function. It's nicely written.

https://github.com/pytorch/elastic/blob/master/examples/imagenet/main.py#L315-L391

def load_checkpoint(
    checkpoint_file: str,
    device_id: int,
    arch: str,
    model: DistributedDataParallel,
    optimizer,  # SGD
) -> State:
    """
    Loads a local checkpoint (if any). Otherwise, checks to see if any of
    the neighbors have a non-zero state. If so, restore the state
    from the rank that has the most up-to-date checkpoint.
    .. note:: when your job has access to a globally visible persistent storage
              (e.g. nfs mount, S3) you can simply have all workers load
              from the most recent checkpoint from such storage. Since this
              example is expected to run on vanilla hosts (with no shared
              storage) the checkpoints are written to local disk, hence
              we have the extra logic to broadcast the checkpoint from a
              surviving node.
    """

    state = State(arch, model, optimizer)

    if os.path.isfile(checkpoint_file):
        print(f"=> loading checkpoint file: {checkpoint_file}")
        state.load(checkpoint_file, device_id)
        print(f"=> loaded checkpoint file: {checkpoint_file}")

    # logic below is unnecessary when the checkpoint is visible on all nodes!
    # create a temporary cpu pg to broadcast most up-to-date checkpoint
    with tmp_process_group(backend="gloo") as pg:
        rank = dist.get_rank(group=pg)

        # get rank that has the largest state.epoch
        epochs = torch.zeros(dist.get_world_size(), dtype=torch.int32)
        epochs[rank] = state.epoch
        dist.all_reduce(epochs, op=dist.ReduceOp.SUM, group=pg)
        t_max_epoch, t_max_rank = torch.max(epochs, dim=0)
        max_epoch = t_max_epoch.item()
        max_rank = t_max_rank.item()

        # max_epoch == -1 means no one has checkpointed return base state
        if max_epoch == -1:
            print(f"=> no workers have checkpoints, starting from epoch 0")
            return state

        # broadcast the state from max_rank (which has the most up-to-date state)
        # pickle the snapshot, convert it into a byte-blob tensor
        # then broadcast it, unpickle it and apply the snapshot
        print(f"=> using checkpoint from rank: {max_rank}, max_epoch: {max_epoch}")

        with io.BytesIO() as f:
            torch.save(state.capture_snapshot(), f)
            raw_blob = numpy.frombuffer(f.getvalue(), dtype=numpy.uint8)

        blob_len = torch.tensor(len(raw_blob))
        dist.broadcast(blob_len, src=max_rank, group=pg)
        print(f"=> checkpoint broadcast size is: {blob_len}")

        if rank != max_rank:
            blob = torch.zeros(blob_len.item(), dtype=torch.uint8)
        else:
            blob = torch.as_tensor(raw_blob, dtype=torch.uint8)

        dist.broadcast(blob, src=max_rank, group=pg)
        print(f"=> done broadcasting checkpoint")

        if rank != max_rank:
            with io.BytesIO(blob.numpy()) as f:
                snapshot = torch.load(f)
            state.apply_snapshot(snapshot, device_id)

        # wait till everyone has loaded the checkpoint
        dist.barrier(group=pg)

    print(f"=> done restoring from previous checkpoint")
    return state
jasperzhong commented 2 years ago

Starting from PyTorch v1.9, elastic training is supported natively; the torchelastic project has been merged upstream.

On the API side, torch.distributed.run is now the recommended entry point. Functionally, torch.distributed.run is a superset of the old torch.distributed.launch, adding fault tolerance and elastic scaling.

The recommended way to launch is now:

>>> python -m torch.distributed.run
    --nnodes=$NUM_NODES
    --nproc_per_node=$NUM_TRAINERS
    --rdzv_id=$JOB_ID
    --rdzv_backend=c10d
    --rdzv_endpoint=$HOST_NODE_ADDR
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

There are three new arguments; the docs describe them as follows:

  1. --rdzv_id: A unique job id (shared by all nodes participating in the job)
  2. --rdzv_backend: An implementation of torch.distributed.elastic.rendezvous.RendezvousHandler. (--rdzv_backend defaults to static, which does not provide fault tolerance or elasticity.)
  3. --rdzv_endpoint: The endpoint where the rendezvous backend is running; usually in form host:port. (--rdzv_endpoint effectively replaces the old --master_addr and --master_port.)

The old launch arguments are still supported for backward compatibility, e.g. --node_rank, --master_addr, and --master_port.
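
For reference, a minimal YOUR_TRAINING_SCRIPT.py for this launcher only needs to read the environment variables the launcher sets for each worker process (a bare-bones sketch; checkpointing and error handling omitted):

import os

import torch.distributed as dist

def main():
    # torch.distributed.run exports these for every worker it spawns.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # env:// picks up MASTER_ADDR / MASTER_PORT, also set by the launcher.
    dist.init_process_group(backend="gloo", init_method="env://")
    print(f"rank {rank}/{world_size} (local_rank {local_rank}) is up")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()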

Elastic Launch

torch/distributed/run.py

The entry function of torch.distributed.run:

def run(args):
    if args.standalone:
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:29400"
        args.rdzv_id = str(uuid.uuid4())
        log.info(
            f"\n**************************************\n"
            f"Rendezvous info:\n"
            f"--rdzv_backend={args.rdzv_backend} "
            f"--rdzv_endpoint={args.rdzv_endpoint} "
            f"--rdzv_id={args.rdzv_id}\n"
            f"**************************************\n"
        )

    config, cmd, cmd_args = config_from_args(args)
    elastic_launch(
        config=config,
        entrypoint=cmd,
    )(*cmd_args)

As you can see, elastic_launch is now the default launch path.
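
elastic_launch can also be called programmatically instead of going through the CLI. A sketch (the exact LaunchConfig fields may vary across versions; the values here just mirror the CLI flags above):

from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def trainer(x):
    # Runs once in every spawned worker process.
    return x * 2

config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=2,
    run_id="toy_job",
    rdzv_backend="c10d",
    rdzv_endpoint="localhost:29400",
)

if __name__ == "__main__":
    # Returns a dict mapping rank -> the entrypoint's return value.
    results = elastic_launch(config=config, entrypoint=trainer)(21)
    print(results)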

torch/distributed/launcher/api.py

elastic_launch calls launch_agent, which creates and starts an agent:

        agent = LocalElasticAgent(
            spec=spec, start_method=config.start_method, log_dir=config.log_dir
        )

        result = agent.run()

Elastic Agent

torch/distributed/elastic/agent/server/api.py

The base class of the agent is ElasticAgent, which mainly exposes the run API, i.e. starting the agent. SimpleElasticAgent inherits from ElasticAgent, implements part of the functionality, and leaves four abstractmethods for subclasses to implement. The agent actually used today is LocalElasticAgent, which inherits from SimpleElasticAgent.
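
As a skeleton, a subclass has to fill in roughly the following hooks (the first three method names come from the walkthrough below; that the fourth abstractmethod is _shutdown, and the exact signatures, are my assumption):

from typing import Any, Dict

from torch.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
)

class MyAgent(SimpleElasticAgent):
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        # Spawn one process per local worker; return {local_rank: worker_id}.
        raise NotImplementedError

    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        # Terminate all local worker processes.
        raise NotImplementedError

    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        # Poll the workers and report SUCCEEDED / HEALTHY / FAILED / ...
        raise NotImplementedError

    def _shutdown(self) -> None:
        # Assumed fourth hook: clean up agent-side resources.
        raise NotImplementedError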

SimpleElasticAgent

First there is the concept of a WorkerGroup, which is basically the set of workers on a single machine, one worker per GPU.

The logic of run lives mainly in _invoke_run, whose flow is as follows:

  1. _initialize_workers: first does a rendezvous to obtain the node rank and node world size, then starts the workers. _start_workers is an abstractmethod.
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r"""
        Starts a fresh set of workers for the worker_group.
        Essentially a rendezvous followed by a start_workers.

        The caller should first call ``_stop_workers()`` to stop running workers
        prior to calling this method.

        Optimistically sets the state of the worker group that
        just started as ``HEALTHY`` and delegates the actual monitoring
        of state to ``_monitor_workers()`` method
        """
        role = worker_group.spec.role
        log.info(f"[{role}] Rendezvous'ing worker group")

        self._rendezvous(worker_group)

        log.info(f"[{role}] Starting worker group")
        worker_ids = self._start_workers(worker_group)
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id

        worker_group.state = WorkerState.HEALTHY

How the rendezvous itself works is covered in the next section; the interface used here is the handler's next_rendezvous function.

  2. monitor: the workers are monitored every 30s, producing a run_result. If run_result is SUCCEEDED, the job finished successfully; if it is UNHEALTHY or FAILED, the workers are restarted (as long as restart attempts remain, otherwise the group is stopped and marked FAILED); if it is HEALTHY and there are waiting nodes, the workers are also restarted.
        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                log.info(
                    f"[{role}] worker group successfully finished."
                    f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
                )
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                if self._remaining_restarts > 0:
                    log.info(
                        f"[{role}] Worker group {state.name}. "
                        f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                        f" will restart worker group"
                    )
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    log.info(
                        f"[{role}] Detected {num_nodes_waiting} "
                        f"new nodes from group_rank={group_rank}; "
                        f"will restart worker group"
                    )
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

_monitor_workers is also an abstractmethod.

_restart_workers: first stops the workers, then re-initializes them.

    def _restart_workers(self, worker_group: WorkerGroup) -> None:
        """
        Restarts (stops, rendezvous, starts) all local workers in the group.
        """

        role = worker_group.spec.role
        log.info(f"[{role}] Stopping worker group")
        self._stop_workers(worker_group)
        worker_group.state = WorkerState.STOPPED
        self._initialize_workers(worker_group)

_stop_workers is also an abstractmethod.

torch/distributed/elastic/agent/server/local_elastic_agent.py

Now for the concrete implementations of the abstractmethods mentioned above.

  1. _start_workers: starts n workers (processes), mainly by calling start_processes.
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

Note:

start_processes returns a PContext (Process Context), which provides APIs for operating on the processes, such as start, close, and wait (a small usage sketch follows after this list).

PContext has two subclasses: MultiprocessContext and SubprocessContext. If the entrypoint is a function, the former is used; if it is a str, the latter is used.

  2. _stop_workers: calls PContext's close operation.

  3. _monitor_workers: calls PContext's wait operation to poll the run status of the processes.
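
A toy usage sketch of start_processes and the resulting PContext, based on the call shown above (the logging arguments have changed in newer PyTorch versions, so treat this as version-dependent):

import tempfile

from torch.distributed.elastic.multiprocessing import start_processes

def trainer(msg):
    print(msg)

if __name__ == "__main__":
    # Function entrypoint -> a MultiprocessContext is created under the hood.
    pc = start_processes(
        name="trainer",
        entrypoint=trainer,
        args={0: ("hello from 0",), 1: ("hello from 1",)},  # per local rank
        envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
        log_dir=tempfile.mkdtemp(),
        start_method="spawn",
    )
    try:
        result = pc.wait()  # block until all workers exit (or one fails)
        if result is not None:
            print(result.return_values)
    finally:
        pc.close()  # terminate any remaining workers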

Rendezvous

The docs explain the term rendezvous as follows:

In the context of Torch Distributed Elastic we use the term rendezvous to refer to a particular functionality that combines a distributed synchronization primitive with peer discovery.

It is used by Torch Distributed Elastic to gather participants of a training job (i.e. nodes) such that they all agree on the same list of participants and everyone’s roles, as well as make a consistent collective decision on when training can begin/resume.

It looks like a tool for reaching consensus: the participants agree on who is in the job, and each gets a rank and the world_size. Whenever a failure occurs, or the job needs to scale up or scale down, a re-rendezvous takes place.

distributed/elastic/rendezvous/dynamic_rendezvous.py

The core rendezvous logic is in the next_rendezvous function of the DynamicRendezvousHandler class. There are three possible outcomes:

The basic logic is to first have the node exit (_RendezvousExitOp) and then rejoin (_RendezvousJoinOp), as shown in the code below:

        try:
            self._stop_heartbeats()

            # Delay the execution for a small random amount of time if this is our
            # first run. This will slightly skew the rendezvous attempts across the
            # nodes and reduce the load on the backend.
            if self._state_holder.state.round == 0:
                _delay(seconds=(0, 0.3))

            exit_op = _RendezvousExitOp()
            join_op = _RendezvousJoinOp()

            deadline = self._get_deadline(self._settings.timeout.join)

            self._op_executor.run(exit_op, deadline)
            self._op_executor.run(join_op, deadline)

            self._start_heartbeats()

            rank, world_size = self._get_world()
            store = self._get_store()

        except Exception as e:
            self._record(
                message=f"{type(e).__name__}: {str(e)}",
                node_state=NodeState.FAILED,
            )
            raise

Concretely, each node maintains a local rendezvous state, which includes the completion status, the participants, the wait list, and so on:

class _RendezvousState:
    """Holds the state of a rendezvous.

    Attributes:
        round:
            The current round of the rendezvous.
        complete:
            A boolean value indicating whether the current round of the
            rendezvous is complete.
        deadline:
            The time at which the current round of the rendezvous will be
            considered complete if it is still waiting for nodes to join.
        closed:
            A boolean value indicating whether the rendezvous is closed.
        participants:
            A dictionary of the participants and their corresponding ranks.
        wait_list:
            A set of nodes that are waiting to participate in the next round of
            the rendezvous.
        last_heartbeats:
            A dictionary containing each node's last heartbeat time.
    """

It's worth mentioning that before executing any op, the state is synced to reach consensus. The sync goes through c10d's TCP key-value store, i.e. the rendezvous backend. Executing an op may update the state: if there were local writes, a compare-and-set is done first (returning the new state and whether our state was updated); if not, a plain get suffices.
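
The sync step can be pictured roughly like this (a simplified sketch, not the actual state-holder code; the backend is assumed to expose get_state/set_state in the style of the rendezvous backend interface):

import pickle

def sync(backend, local_state, token, has_local_writes):
    # Toy sketch of the pre-op state sync.
    if has_local_writes:
        # Compare-and-set against the kv store: returns the winning state
        # bytes, a fresh token, and whether our write was accepted.
        state_bytes, token, _accepted = backend.set_state(
            pickle.dumps(local_state), token
        )
    else:
        # No local writes: reading the latest state is enough.
        state_bytes, token = backend.get_state()
    return pickle.loads(state_bytes), token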

  1. exit: if the node is not among the participants, do nothing; if it is, remove it from the participants list.
class _RendezvousExitOp:
    """Represents a rendezvous exit operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        if ctx.node in ctx.state.participants:
            if time.monotonic() > deadline:
                return _Action.ERROR_TIMEOUT
            return _Action.REMOVE_FROM_PARTICIPANTS
        return _Action.FINISH
  2. join: tries to add the node to the participants; it may end up on the wait list instead, or may not even get onto the wait list and has to keep waiting.
        if state.complete:
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes:
                if ctx.node not in state.wait_list:
                    return _Action.ADD_TO_WAIT_LIST
        elif is_participant:
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if len(state.participants) >= ctx.settings.min_nodes:
                if cast(datetime, state.deadline) < datetime.utcnow():
                    return _Action.MARK_RENDEZVOUS_COMPLETE
        else:
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS
  3. keep-alive: besides join and exit, there is a keep-alive op used to detect whether a node has failed. It runs periodically (every 5s) and updates the last heartbeat in the local state. At the next state sync, the last heartbeats are written to the kv store, so other healthy nodes learn about them.
    def _keep_alive(self) -> None:
        msg = (
            f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
            f"'{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        log.debug(msg)

        self._state.last_heartbeats[self._node] = datetime.utcnow()

state.last_heartbeats is used to filter out dead nodes: if a node's last heartbeat is older than a certain threshold, it is marked dead and removed from the state's participants or wait list.
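
A toy sketch of that liveness filter over the _RendezvousState fields above (the interval and miss-count parameters are made up; the real logic runs inside the state holder's sync path):

from datetime import datetime, timedelta

def filter_dead_nodes(state, keep_alive_interval_s=5, max_missed=3):
    # Any node whose last heartbeat is too old is dropped everywhere.
    threshold = datetime.utcnow() - timedelta(
        seconds=keep_alive_interval_s * max_missed
    )
    dead = [n for n, ts in state.last_heartbeats.items() if ts < threshold]
    for node in dead:
        state.participants.pop(node, None)  # dict: node -> rank
        state.wait_list.discard(node)       # set of waiting nodes
        del state.last_heartbeats[node]
    return dead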

As you can see, the kv-store design is quite clever, very much like a master-client design: the master holds some metadata, clients update it, every client can see it, and the master periodically removes clients that stop updating their metadata. This does simplify the design quite a bit.

jasperzhong commented 2 years ago

pytorch elastic

jasperzhong commented 2 years ago

(image)

jasperzhong commented 2 years ago

It seems I never looked carefully at the static mode before. I'll take a proper look at it later.