yhna940 closed this pull request 3 months ago.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks! Overall this looks great to me, a few requests though:
- Can we add an example in `examples/by_feature` specifically showcasing this usage?
- Can we add some documentation to the official docs on it? Ideally a Usage Guide if nothing else (though a Concept Guide too would be ideal!)
Thank you for the review and suggestions :)
In response to your feedback, I have made the following updates:
- `examples/by_feature/ddp_comm_hook.py`
- `docs/source/usage_guides/ddp_comm_hook.md`

These additions aim to provide clear guidance on how to use DDP communication hooks with the 🤗 Accelerate library, improving the usability and performance of distributed training.
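As background on what the simplest of these hooks does: the FP16 compression hook casts each gradient bucket to half precision before the allreduce and back to float32 afterwards, halving communication volume at a small precision cost. A minimal single-process sketch of that round trip (plain tensors, no process group — the real hook operates on DDP's `GradBucket` objects):

```python
import torch

def fp16_compress(grad: torch.Tensor) -> torch.Tensor:
    # Halve the bytes sent over the wire by casting to float16.
    return grad.to(torch.float16)

def fp16_decompress(compressed: torch.Tensor) -> torch.Tensor:
    # Restore float32 after the (averaged) tensor comes back.
    return compressed.to(torch.float32)

grad = torch.randn(1024)
restored = fp16_decompress(fp16_compress(grad))

# float16 keeps roughly 3 significant decimal digits, so for typical
# gradient magnitudes the round-trip error is small.
print(restored.dtype, (grad - restored).abs().max().item() < 1e-2)
```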
Please let me know if there are any further adjustments or additions required.
Thank you for the feedback and the suggestion @SunMarc. I have added more details about `comm_wrapper` and `comm_state_option` in the documentation. As you mentioned, `comm_state_option` is currently only applied to `PowerSGD` in this PR.
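For reference, the state object involved here is PyTorch's `PowerSGDState`; constructing one directly shows the kind of knobs a state option configures. The specific values below are illustrative, not defaults from this PR, and how Accelerate forwards such kwargs is an assumption of this sketch:

```python
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

# Illustrative configuration of PowerSGD's hook state; no process group
# is needed just to construct the state object.
state = powerSGD.PowerSGDState(
    process_group=None,            # use the default process group when registered
    matrix_approximation_rank=2,   # rank of the low-rank gradient approximation
    start_powerSGD_iter=100,       # run vanilla allreduce for the first 100 steps
)
print(state.matrix_approximation_rank, state.start_powerSGD_iter)
```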
However, there are other state-using hooks available in PyTorch, such as `post_localSGD_hook`. These hooks require specific optimizers and provide advanced features for gradient communication. For example, `post_localSGD_hook` is closely tied to the `PostLocalSGDOptimizer` and involves additional setup:
```python
import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.nn as nn
from torch.distributed.optim import PostLocalSGDOptimizer
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)

# `module`, `rank`, `loss_fn`, and the training data are assumed to be defined
# elsewhere, and the process group to be initialized already.
model = nn.parallel.DistributedDataParallel(
    module, device_ids=[rank], output_device=rank
)

# Register a post-localSGD communication hook.
state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
model.register_comm_hook(state, post_localSGD_hook)

# Create a post-localSGD optimizer that wraps a local optimizer.
# Note that `warmup_steps` used in `PostLocalSGDOptimizer` must be the same as
# `start_localSGD_iter` used in `PostLocalSGDState`.
local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
opt = PostLocalSGDOptimizer(
    optim=local_optim,
    averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
)

# In the first 100 steps, DDP runs global gradient averaging at every step.
# After 100 steps, DDP runs gradient averaging within each subgroup
# (intra-node by default), and the post-localSGD optimizer runs global model
# averaging every 4 steps after applying the local optimizer.
for step in range(0, 200):
    opt.zero_grad()
    output = model(inputs)  # `inputs` and `labels` come from the dataloader
    loss = loss_fn(output, labels)
    loss.backward()
    opt.step()
```
These advanced hooks were not included in this PR because the PyTorch official documentation primarily highlights PowerSGD, FP16, and BF16 hooks. You can find more information about additional hooks in the PyTorch DDP Communication Hooks documentation and the PyTorch GitHub repository.
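To give some intuition for why the PowerSGD hook included in this PR helps: it approximates each gradient matrix with a low-rank factorization, so workers exchange two thin factor matrices instead of the full gradient. A hypothetical single-process sketch of one power-iteration step (no distributed communication, just the linear algebra):

```python
import torch

# Approximate an n x m "gradient" with rank-r factors P (n x r) and Q (m x r),
# so only n*r + m*r numbers would need to be allreduced instead of n*m.
torch.manual_seed(0)
n, m, r = 64, 32, 4
grad = torch.randn(n, r) @ torch.randn(r, m)  # a genuinely low-rank gradient

Q = torch.randn(m, r)          # random projection matrix
P = grad @ Q                   # project onto a rank-r subspace
P, _ = torch.linalg.qr(P)      # orthogonalize (one power-iteration step)
Q = grad.t() @ P               # refit the second factor
approx = P @ Q.t()             # "decompressed" gradient

compression_ratio = (n * m) / (n * r + m * r)
rel_err = ((grad - approx).norm() / grad.norm()).item()
print(compression_ratio, rel_err)
```

Since the toy gradient above is exactly rank `r`, one iteration recovers it almost perfectly; real gradients are only approximately low-rank, which is why PowerSGD pairs this with error feedback.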
Thanks for the quick review @SunMarc, I've done it :)
Thank you for your detailed review @stevhliu :) I've changed the guide as you suggested: I used the `hfoption` tag and rearranged some sentences. I'll fix it if there's anything wrong, thanks 🤗
Nicely done @yhna940!
What does this PR do?
This PR adds support for DDP communication hooks to the `accelerate` library. Similar to frameworks like PyTorch Lightning and Detectron, these hooks provide an interface to control how gradients are communicated across workers, overriding the standard allreduce in `DistributedDataParallel`. This feature enables the use of performance-improving communication hooks when using multiple nodes.

Motivation and Context
DDP communication hooks allow users to customize and optimize gradient communication, potentially improving training performance in distributed settings.
Based on the official PyTorch documentation here, I've implemented three default hooks: PowerSGD, FP16, and BF16. These hooks provide performance improvements in distributed training scenarios.
The implementation for registering these hooks was inspired by the PyTorch Lightning implementation, which can be found here.
Fixes # (issue)
N/A
Before submitting