Context
Communication/computation overlap is a well-known theme in data-parallel training: developers exploit independence among the forward, backward, and optimizer passes of model training and run them in parallel to improve performance. Recently there has been interest in overlapping optimizer computation with the DDP backward pass to further increase communication/computation overlap.
Initial experiments with a prototype that overlaps the SGD optimizer step with the DDP backward pass have shown a ~15% peak memory reduction on a few benchmarks, which enables training larger models, or using larger batch sizes, with the same amount of GPU memory. Peak memory drops because the optimizer only needs to load the state of one parameter while running step for that parameter, instead of loading all per-parameter optimizer states into lists before running step() for all parameters.
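To make the memory argument concrete, here is a toy pure-Python model, not the prototype's code: floats stand in for tensors, the helper names are hypothetical, and "peak" counts gradient/state entries held at once rather than bytes. A regular step() gathers every parameter's gradient and momentum buffer into lists first, while the overlapped version updates each parameter as soon as its gradient arrives during backward.

```python
from typing import Dict, Iterable, Tuple


def sgd_momentum_update(
    p: float, g: float, buf: float, lr: float, momentum: float
) -> Tuple[float, float]:
    """Single-parameter SGD-with-momentum update; returns (new_param, new_buf)."""
    buf = momentum * buf + g
    return p - lr * buf, buf


def batched_step(params: Dict[str, float], grads: Dict[str, float],
                 bufs: Dict[str, float], lr: float = 0.1, momentum: float = 0.9) -> int:
    """Regular step(): gather all gradients and states into lists, then update.
    Returns the peak number of entries held at once."""
    names = list(params)
    grad_list = [grads[n] for n in names]
    buf_list = [bufs[n] for n in names]
    peak = len(grad_list) + len(buf_list)
    for n, g, b in zip(names, grad_list, buf_list):
        params[n], bufs[n] = sgd_momentum_update(params[n], g, b, lr, momentum)
    return peak


def overlapped_steps(params: Dict[str, float], grad_stream: Iterable[Tuple[str, float]],
                     bufs: Dict[str, float], lr: float = 0.1, momentum: float = 0.9) -> int:
    """Overlapped mode: update each parameter when its gradient becomes ready,
    touching only that parameter's gradient and state."""
    peak = 0
    for n, g in grad_stream:
        working_set = {"grad": g, "buf": bufs[n]}
        peak = max(peak, len(working_set))
        params[n], bufs[n] = sgd_momentum_update(
            params[n], working_set["grad"], working_set["buf"], lr, momentum
        )
    return peak


params_a = {"w1": 1.0, "w2": 2.0, "w3": 3.0}
grads = {"w1": 0.5, "w2": -1.0, "w3": 0.25}
bufs_a = {n: 0.0 for n in params_a}
params_b, bufs_b = dict(params_a), dict(bufs_a)

peak_batched = batched_step(params_a, grads, bufs_a)
# Backward produces gradients roughly in reverse order of the forward pass.
peak_overlapped = overlapped_steps(params_b, reversed(list(grads.items())), bufs_b)
```

Both paths produce the same updated parameters; only the number of per-parameter entries alive at the same time differs (6 versus 2 here), and that gap grows with the number of parameters.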
Goal
Enable optimizer overlap with the DDP backward pass as a prototype feature in PyTorch, so that it can be used to experiment on various workloads with the goals of saving memory and increasing efficiency.
Further investigate the ZeRO optimizer to see whether ZeRO communication can be overlapped with the DDP backward pass, reducing some of ZeRO's communication cost.
Initially Supported Use Cases
Overlap regular (non-distributed) optimizers, such as SGD, Adam, and Adagrad, with the DDP backward pass.
Support optimizer overlap for ZeRORedundancyOptimizer (note that some prior work on this has already been done).
API Proposal
At a high level, users will call a register_fused_optim API to enable fusing the optimizer with the DDP backward pass. The following code is a skeleton of the proposed API, written by @cbalioglu:
import types
from abc import ABC
from typing import Any, Callable, Dict, Sequence, Type

import torch
from torch.distributed._fsdp import FullyShardedDataParallel
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer

from .ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleave,
)
from .ddp_comm_hooks.default_hooks import (
    _hook_then_optimizer,
    _OptimizerHookState,
    allreduce_hook,
)
from .utils import as_functional_optim  # TODO: Implement

# Contains the mappings between the regular and overlapped optimizer types.
_registered_overlapped_optims: Dict[Type, Type] = {}


# Decorator function that registers the mapping between a regular and an
# overlapped optimizer, used as `@register(ZeroRedundancyOptimizer)`.
def register(optim_cls: Type) -> Callable[[Type], Type]:
    def decorator(overlapped_optim_cls: Type) -> Type:
        _registered_overlapped_optims[optim_cls] = overlapped_optim_cls
        return overlapped_optim_cls

    return decorator


class OverlappedOptimizer(ABC):
    """Represents an overlapped optimizer. Note that this is a public API."""

    def __init__(self, optim: Optimizer) -> None:
        self.optim = optim

        # The step now runs inside the DDP communication hook, so an explicit
        # `step()` call from the training loop is an error.
        def step(_, closure=None):
            raise RuntimeError(
                "The `step()` method cannot be called in overlapped mode."
            )

        self.optim.step = types.MethodType(step, optim)

    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        """Registers the overlapped optimizer with DDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped DDP."
        )

    def register_fsdp(self, fsdp: FullyShardedDataParallel) -> None:
        """Registers the overlapped optimizer with FSDP."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support overlapped FSDP."
        )


class _OverlappedStandardOptimizer(OverlappedOptimizer):
    """Overlaps a regular ``Optimizer``."""

    def __init__(self, optim: Optimizer) -> None:
        super().__init__(optim)
        f_optim = as_functional_optim(optim, allow_empty_param_list=True)
        # Requires a new `__init__` overload that accepts a functional optimizer
        # instance.
        self._hook_state = _OptimizerHookState(f_optim)

    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        # Allreduce each gradient bucket, then run the optimizer step on the
        # parameters in that bucket.
        ddp.register_comm_hook(
            None, _hook_then_optimizer(allreduce_hook, self._hook_state)
        )


@register(ZeroRedundancyOptimizer)
class _OverlappedZeroRedundancyOptimizer(OverlappedOptimizer):
    """Overlaps the ``ZeroRedundancyOptimizer``."""

    def __init__(
        self,
        optim: ZeroRedundancyOptimizer,
        interleaved: bool = False,
        shard_buckets: bool = False,
    ) -> None:
        super().__init__(optim)
        if interleaved:
            self._hook_f = hook_with_zero_step_interleave
        else:
            self._hook_f = hook_with_zero_step
        self._shard_buckets = shard_buckets

    def register_ddp(self, ddp: DistributedDataParallel) -> None:
        ddp.register_comm_hook(
            None, self._hook_f(allreduce_hook, ddp, self.optim, self._shard_buckets)
        )


def _as_overlapped_optim(
    optim: Optimizer, fused_args: Sequence[Any], fused_kwargs: Dict[str, Any]
) -> OverlappedOptimizer:
    """Returns a new ``OverlappedOptimizer`` instance that supports ``optim``."""
    for optim_cls, overlapped_optim_cls in _registered_overlapped_optims.items():
        if isinstance(optim, optim_cls):
            return overlapped_optim_cls(optim, *fused_args, **fused_kwargs)
    # Fall back to the standard overlapped optimizer.
    return _OverlappedStandardOptimizer(optim)


# Proposed addition to the existing DDP class (other members omitted).
class DistributedDataParallel:
    def register_fused_optim(self, optim: Optimizer, *fused_args, **fused_kwargs):
        overlapped_optim = _as_overlapped_optim(optim, fused_args, fused_kwargs)
        try:
            overlapped_optim.register_ddp(self)
        except NotImplementedError as e:
            raise RuntimeError(
                f"{optim.__class__.__name__} does not support overlapped DDP."
            ) from e

    # Rest of DDP...


# Proposed addition to the existing FSDP class (other members omitted).
class FullyShardedDataParallel:
    def register_fused_optim(self, optim: Optimizer, *fused_args, **fused_kwargs):
        overlapped_optim = _as_overlapped_optim(optim, fused_args, fused_kwargs)
        try:
            overlapped_optim.register_fsdp(self)
        except NotImplementedError as e:
            raise RuntimeError(
                f"{optim.__class__.__name__} does not support overlapped FSDP."
            ) from e

    # Rest of FSDP...
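The control flow the skeleton relies on, where DDP runs a communication hook per gradient bucket and the fused optimizer steps that bucket's parameters right after the allreduce, can be sketched in a single process. Everything here is hypothetical and simplified: `ToyDDP`, `allreduce_mean`, and `hook_then_optimizer` are illustrative stand-ins (the last loosely mirrors the role `_hook_then_optimizer` plays in the skeleton), ranks are simulated as lists, and plain SGD is assumed.

```python
from typing import Callable, Dict, List

Bucket = Dict[str, float]  # parameter name -> gradient (floats stand in for tensors)
CommHook = Callable[[List[Bucket]], Bucket]


def allreduce_mean(bucket_per_rank: List[Bucket]) -> Bucket:
    """Simulated allreduce: average each gradient across all ranks."""
    world_size = len(bucket_per_rank)
    return {n: sum(b[n] for b in bucket_per_rank) / world_size for n in bucket_per_rank[0]}


def hook_then_optimizer(hook: CommHook, step_fn: Callable[[Bucket], None]) -> CommHook:
    """Wrap a communication hook so the optimizer runs on each reduced bucket."""
    def fused_hook(bucket_per_rank: List[Bucket]) -> Bucket:
        reduced = hook(bucket_per_rank)
        step_fn(reduced)  # optimizer step for just this bucket's parameters
        return reduced
    return fused_hook


class ToyDDP:
    """Hypothetical single-process stand-in for DDP's bucketed backward."""

    def __init__(self, params: Dict[str, float]) -> None:
        self.params = params
        self.comm_hook: CommHook = allreduce_mean

    def register_comm_hook(self, hook: CommHook) -> None:
        self.comm_hook = hook

    def register_fused_optim(self, lr: float) -> None:
        def sgd_step(reduced: Bucket) -> None:
            for name, grad in reduced.items():
                self.params[name] -= lr * grad
        self.register_comm_hook(hook_then_optimizer(allreduce_mean, sgd_step))

    def backward(self, buckets_per_rank: List[List[Bucket]]) -> None:
        # buckets_per_rank[r][i] is bucket i on rank r. Each bucket is handed to
        # the comm hook as soon as it is ready, one bucket at a time.
        for i in range(len(buckets_per_rank[0])):
            self.comm_hook([rank_buckets[i] for rank_buckets in buckets_per_rank])


ddp = ToyDDP({"w1": 1.0, "w2": 2.0, "b": 0.0})
ddp.register_fused_optim(lr=0.5)
rank0 = [{"b": 1.0}, {"w1": 2.0, "w2": 4.0}]  # two buckets, in ready order
rank1 = [{"b": 3.0}, {"w1": 0.0, "w2": 0.0}]
ddp.backward([rank0, rank1])
# No optimizer.step() call: the parameters were updated during backward.
print(ddp.params)  # {'w1': 0.5, 'w2': 1.0, 'b': -1.0}
```

Because the optimizer is folded into the hook, the first bucket's parameters are already updated while later buckets are still being reduced, which is the overlap the proposal is after.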
Outstanding Questions
What if the user still calls step() in the training loop?
Options: document that users should not call step(); patch optimizer.step() to raise a ValueError if called; or patch optimizer.step() to be a no-op.
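A minimal sketch of the two patching options, assuming a plain Python object in place of a real optimizer (`ToyOptimizer` and `patch_step` are hypothetical names). Binding the replacement with `types.MethodType`, as the skeleton does, changes only the registered instance.

```python
import types
from typing import Optional


class ToyOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer instance."""

    def __init__(self) -> None:
        self.calls = 0

    def step(self, closure=None) -> None:
        self.calls += 1


def patch_step(optim: ToyOptimizer, mode: str = "raise") -> None:
    """Replace ``step`` on this one instance; ``mode`` is 'raise' or 'noop'."""
    if mode == "raise":
        def step(_, closure=None):
            raise ValueError("step() runs inside DDP backward for a fused optimizer; do not call it")
    elif mode == "noop":
        def step(_, closure=None):
            return None
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    # Binding to the instance leaves the class and other instances untouched.
    optim.step = types.MethodType(step, optim)


strict, lenient, untouched = ToyOptimizer(), ToyOptimizer(), ToyOptimizer()
patch_step(strict, "raise")
patch_step(lenient, "noop")

lenient.step()    # silently ignored
untouched.step()  # original behavior
caught: Optional[str] = None
try:
    strict.step()
except ValueError as e:
    caught = str(e)
```

Raising surfaces a mistaken step() call immediately, while the no-op keeps existing training loops running unchanged but can hide the mistake.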
The user can pass only a subset of parameters to an optimizer, or have multiple optimizers optimizing possibly overlapping subsets of the model parameters. How will this be supported?
DDP's register_fused_optim will maintain, for each optimizer, the set of parameters it optimizes, taken from the optimizer's param_groups. Then, during the communication hook's execution, we simply check whether the parameter corresponding to the current gradient should be optimized by that optimizer.
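A sketch of that per-optimizer bookkeeping, assuming parameters are identified by name (a real implementation would key on the parameter tensors themselves); `FusedOptimParams` and the lambda "optimizers" are hypothetical. A parameter owned by two optimizers is stepped by both, and a parameter owned by none is skipped.

```python
from typing import Callable, Iterable, List, Set, Tuple

StepFn = Callable[[str, float], None]  # (parameter name, reduced gradient) -> None


class FusedOptimParams:
    """Hypothetical bookkeeping for register_fused_optim: which parameters each
    registered optimizer owns. Names stand in for parameter identity."""

    def __init__(self) -> None:
        self._owners: List[Tuple[Set[str], StepFn]] = []

    def register(self, param_names: Iterable[str], step_fn: StepFn) -> None:
        self._owners.append((set(param_names), step_fn))

    def on_reduced_grad(self, name: str, grad: float) -> None:
        # Called from the comm hook for each reduced gradient. Only optimizers
        # that own this parameter step it; with overlapping subsets, several may.
        for owned, step_fn in self._owners:
            if name in owned:
                step_fn(name, grad)


log: List[Tuple[str, str]] = []
fused = FusedOptimParams()
fused.register(["w1", "w2"], lambda n, g: log.append(("sgd", n)))
fused.register(["w2", "b"], lambda n, g: log.append(("adam", n)))
for name, grad in [("b", 0.1), ("w2", 0.2), ("w1", 0.3), ("frozen", 0.4)]:
    fused.on_reduced_grad(name, grad)
print(log)
```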
Can we remove Reducer::copy_bucket_to_grad in Reducer::finalize_backward?
For now, the agreement is that we can disable this copy when the feature is enabled: it only copies the reduced gradient into the param.grad field, and the optimizer will not read param.grad. Initially, this feature will not support gradient accumulation, no_sync, or other uses of the param.grad field, but we can consider these if there is a concrete use case. We will also not populate the param.grad field.
Another option is to require that users only enable this feature together with gradient_as_bucket_view=True.
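Why the copy becomes unnecessary with gradient_as_bucket_view can be illustrated with a pure-Python analogy, assuming a flat `array` stands in for a DDP gradient bucket and `memoryview` slices stand in for per-parameter gradient views (the layout and names are made up for illustration). The views see the reduced values in place, whereas separately stored gradients do not and would need an explicit copy, which is the role copy_bucket_to_grad plays.

```python
from array import array
from typing import Dict, Tuple

# Hypothetical layout of one flat bucket: "w" has 3 elements, "b" has 1.
layout: Dict[str, Tuple[int, int]] = {"w": (0, 3), "b": (3, 4)}


def grad_views(bucket: array) -> Dict[str, memoryview]:
    """Per-parameter gradients as views into the flat bucket (no copies)."""
    mv = memoryview(bucket)
    return {name: mv[start:end] for name, (start, end) in layout.items()}


bucket_rank0 = array("d", [1.0, 2.0, 3.0, 4.0])
bucket_rank1 = array("d", [3.0, 0.0, 1.0, 2.0])

views = grad_views(bucket_rank0)
copied_w = bucket_rank0[0:3].tolist()  # separate storage, like a plain param.grad

# Simulated allreduce: average in place into rank 0's bucket.
for i in range(len(bucket_rank0)):
    bucket_rank0[i] = (bucket_rank0[i] + bucket_rank1[i]) / 2

print(views["w"].tolist(), views["b"].tolist())  # [2.0, 1.0, 2.0] [3.0]
print(copied_w)  # [1.0, 2.0, 3.0]: stale until explicitly copied again
```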
Misc Details
For now, we will prioritize building this out for DDP. It can also be extended to FSDP fairly easily, but FSDP first needs some prerequisites, such as communication hook support.
For now, this feature will not support no_sync mode (i.e., gradients must be synchronized every iteration) or gradient accumulation.
A request from the FairScale team is to extend this to support _multi_tensor optimizers.
with @cbalioglu
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang