
RFC: Overlap optimizer computation with DDP/FSDP backward #67570

Open · rohan-varma opened this issue 2 years ago

rohan-varma commented 2 years ago

with @cbalioglu

Context

Communication/computation overlap is a well-known theme in data parallel training: developers exploit independence between the forward, backward, and optimizer phases of model training and parallelize them to improve performance. Recently there has been interest in overlapping optimizer computation with the DDP backward pass to further increase communication/computation overlap.

Initial experiments with a prototype that overlaps the SGD optimizer step with the DDP backward pass have shown a ~15% peak memory reduction on a few benchmarks, which enables training larger models or using a larger batch size with the same amount of GPU memory. Peak memory drops because the optimizer only needs to load the state of one parameter while running step for that parameter, instead of loading all per-parameter optimizer states into lists before running step() for every parameter.
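To make the memory argument concrete, here is a minimal, purely illustrative sketch (not the proposed implementation, which is built on DDP communication hooks and functional optimizers, see below): a plain SGD update is applied from a per-parameter gradient hook, so each parameter is updated as soon as its gradient is ready and only that parameter and its gradient need to be touched at that point.

import torch
from torch import nn

# Toy model and learning rate; purely illustrative.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 1))
lr = 0.01

def make_hook(param):
    # Fires as soon as the gradient for `param` is produced during backward,
    # so the update for this parameter overlaps with the remainder of the
    # backward pass, and only this parameter and its gradient are needed here.
    def hook(grad):
        with torch.no_grad():
            param.add_(grad, alpha=-lr)  # plain SGD update for one parameter
        return grad
    return hook

for p in model.parameters():
    p.register_hook(make_hook(p))

loss = model(torch.randn(8, 32)).sum()
loss.backward()  # parameters are updated as their gradients become ready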

Goal

Initially Supported Use Cases

API Proposal

At a high level, the user will call a register_fused_optim API to fuse the optimizer step into the DDP backward pass. The following code, written by @cbalioglu, is a skeleton of the proposed API:

import types

from abc import ABC
from typing import Any, Dict, Sequence, Type

import torch

from torch.distributed._fsdp import FullyShardedDataParallel
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer

from .ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleave,
)
from .ddp_comm_hooks.default_hooks import (
    _hook_then_optimizer,
    _OptimizerHookState,
    allreduce_hook,
)
from .utils import as_functional_optim  # TODO: Implement

# Contains the mappings between the regular and overlapped optimizer types.
_registered_overlapped_optims: Dict[Type, Type] = {}

# Decorator factory that registers the mapping between a regular optimizer type
# and its overlapped counterpart in ``_registered_overlapped_optims``.
def register(optim_cls):
    ...

class OverlappedOptimizer(ABC):
    """Represents an overlapped optimizer. Note that this is a public API."""

    def __init__(self, optim: Optimizer) -> None:
        self.optim = optim

        # Disallow manual calls to step(); the update now runs during the
        # backward pass.
        def step(_, closure=None):
            raise RuntimeError(
                "The `step()` method cannot be called in overlapped mode."
            )

        self.optim.step = types.MethodType(step, optim)

    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        """Registers the overlapped optimizer with DDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped DDP."
        )

    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Registers the overlapped optimizer with FSDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped FSDP."
        )

class _OverlappedStandardOptimizer(OverlappedOptimizer):
    """Overlaps a regular ``Optimizer``."""

    def __init__(self, optim: Optimizer) -> None:
        super().__init__(optim)

        f_optim = as_functional_optim(optim, allow_empty_param_list=True)

        # Requires a new `__init__` overload that accepts a functional optimizer
        # instance.
        self._hook_state = _OptimizerHookState(f_optim)

    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        # Run allreduce on each gradient bucket, then apply the functional
        # optimizer step to the parameters of that bucket.
        ddp.register_comm_hook(
            None, _hook_then_optimizer(allreduce_hook, self._hook_state)
        )

@register(ZeroRedundancyOptimizer)        
class _OverlappedZeroRedundancyOptimizer(OverlappedOptimizer):
    """Overlaps the ``ZeroRedundancyOptimizer``."""

    def __init__(
        self,
        optim: ZeroRedundancyOptimizer,
        interleaved: bool = False,
        shard_buckets: bool = False,
    ) -> None:
        super().__init__(optim)

        if interleaved:
            self._hook_f = hook_with_zero_step_interleave
        else:
            self._hook_f = hook_with_zero_step

        self._shard_buckets = shard_buckets

    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        ddp.register_comm_hook(
            None, self._hook_f(allreduce_hook, ddp, self.optim, self._shard_buckets)
        )

def _as_overlapped_optim(
    optim: Optimizer, fused_args: Sequence[Any], fused_kwargs: Dict[str, Any]
) -> OverlappedOptimizer:
    """Returns a new ``OverlappedOptimizer`` instance that supports ``optim``."""
    for optim_cls, overlapped_optim_cls in _registered_overlapped_optims.items():
        if isinstance(optim, optim_cls):
            return overlapped_optim_cls(optim, *fused_args, **fused_kwargs)

    # Fall back to the standard overlapped optimizer.
    return _OverlappedStandardOptimizer(optim)

class DistributedDataParallel:
    def register_fused_optim(self, optim: Optimizer, *fused_args, **fused_kwargs):
        overlapped_optim = _as_overlapped_optim(optim, fused_args, fused_kwargs)

        try:
            overlapped_optim.register_ddp(self)
        except NotImplementedError:
            raise RuntimeError(
                f"{optim.__class__.__name__} does not support overlapped DDP."
            )

    # Rest of DDP...

class FullyShardedDataParallel:
    def register_fused_optim(self, optim: Optimizer, *fused_args, **fused_kwargs):
        overlapped_optim = _as_overlapped_optim(optim, fused_args, fused_kwargs)

        try:
            overlapped_optim.register_fsdp(self)
        except NotImplementedError:
            raise RuntimeError(
                f"{optim.__class__.__name__} does not support overlapped FSDP."
            )

    # Rest of FSDP...
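For reference, end-to-end usage under this skeleton might look like the following. The model, process-group setup, and hyperparameters are hypothetical, and the final API surface may differ from the proposal above.

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes the usual per-process launcher environment (rank, world size, etc.).
dist.init_process_group("nccl")
device = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(device)

model = nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model, device_ids=[device])

# Construct the optimizer as usual...
optim = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

# ...then register it so its step is fused into the DDP backward pass.
# After this call, invoking optim.step() manually raises a RuntimeError.
ddp_model.register_fused_optim(optim)

for _ in range(10):
    loss = ddp_model(torch.randn(32, 1024, device="cuda")).sum()
    loss.backward()  # allreduce and optimizer step run per bucket, overlapped
    ddp_model.zero_grad()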

Outstanding Questions

Misc Details

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

byronyi commented 2 years ago

Are you still working on this? @rohan-varma