huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0
7.77k stars, 941 forks

Add DDP Communication Hooks #2841

Closed: yhna940 closed this 3 months ago

yhna940 commented 3 months ago

What does this PR do?

This PR adds support for DDP communication hooks to the accelerate library. As in frameworks like PyTorch Lightning and Detectron, these hooks provide an interface for controlling how gradients are communicated across workers, overriding the standard allreduce in DistributedDataParallel. This lets users plug in communication hooks that reduce gradient-communication overhead, which is most useful when training across multiple nodes.
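A DDP communication hook is a callable that takes `(state, bucket)` and returns a `torch.futures.Future` resolving to the reduced gradient bucket. The sketch below writes, by hand, a hook that reproduces the default averaging allreduce and registers it with `register_comm_hook`. It assumes a single CPU process with the gloo backend purely to show the wiring; `averaging_allreduce_hook` and `run_single_process_demo` are illustrative names, not part of this PR.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def averaging_allreduce_hook(process_group, bucket):
    """Hand-written hook matching DDP's default behaviour: average each bucket."""
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(group)
    # Pre-divide so that the summed allreduce result is the mean gradient.
    tensor = bucket.buffer().div_(world_size)
    work = dist.all_reduce(tensor, group=group, async_op=True)
    # The allreduce future resolves to a list of tensors; DDP expects a single tensor.
    return work.get_future().then(lambda fut: fut.value()[0])


def run_single_process_demo():
    # Assumption: one local CPU process with gloo; the port number is arbitrary.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        model = DDP(nn.Linear(4, 2))
        # state=None means "use the default process group" for this hook.
        model.register_comm_hook(state=None, hook=averaging_allreduce_hook)
        model(torch.randn(8, 4)).sum().backward()
        return [p.grad.clone() for p in model.parameters()]
    finally:
        dist.destroy_process_group()


grads = run_single_process_demo()
```

Swapping out the body of such a hook (for example, casting the buffer to half precision before the allreduce) is how the built-in compression hooks change communication without touching the training loop.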

Motivation and Context

DDP communication hooks allow users to customize and optimize gradient communication, potentially improving training performance in distributed settings.

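Following the official PyTorch DDP communication hooks documentation, I've implemented three default hooks: PowerSGD, FP16, and BF16. The FP16 and BF16 hooks cast gradients to half precision before the allreduce, halving the communication volume, while PowerSGD compresses gradients with a low-rank approximation; both approaches can reduce communication time in distributed training.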
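To give a rough sense of the savings, the back-of-the-envelope sketch below counts the bytes sent per step for a single fp32 gradient matrix of shape n x m. It assumes that PowerSGD with rank r transmits two factors, P (n x r) and Q (m x r), in fp32, and it ignores allreduce algorithm overhead, the warm-up period before PowerSGD starts compressing, and any tensors PowerSGD leaves uncompressed. `payload_bytes` is a hypothetical helper, not a library function.

```python
def payload_bytes(n: int, m: int, hook: str, rank: int = 1) -> int:
    """Approximate bytes allreduced per step for one n x m fp32 gradient."""
    if hook == "none":  # default allreduce: fp32, 4 bytes per value
        return n * m * 4
    if hook in ("fp16", "bf16"):  # same number of values, 2 bytes each
        return n * m * 2
    if hook == "power_sgd":  # low-rank factors P (n x r) and Q (m x r) in fp32
        return rank * (n + m) * 4
    raise ValueError(f"unknown hook: {hook}")


for hook in ("none", "fp16", "power_sgd"):
    print(hook, payload_bytes(1024, 1024, hook, rank=2))
# none 4194304
# fp16 2097152
# power_sgd 16384
```

Under these assumptions, the half-precision hooks halve the payload, while rank-2 PowerSGD sends 256x fewer bytes for this layer. The trade-off is an approximate gradient, which PyTorch's implementation mitigates with error feedback.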

The hook registration logic is modeled on PyTorch Lightning's implementation.
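The Lightning-style registration flow (pick a hook, build its state, optionally wrap it, then call `register_comm_hook`) can be sketched as follows. `register_ddp_comm_hook`, `HOOKS`, and `WRAPPERS` are hypothetical names for illustration. Only the `torch.distributed.algorithms.ddp_comm_hooks` functions and `PowerSGDState` are real PyTorch APIs, and this is not the code merged in the PR.

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks, powerSGD_hook

# Hypothetical name -> PyTorch hook mapping.
HOOKS = {
    "fp16": default_hooks.fp16_compress_hook,
    "bf16": default_hooks.bf16_compress_hook,
    "power_sgd": powerSGD_hook.powerSGD_hook,
    "batched_power_sgd": powerSGD_hook.batched_powerSGD_hook,
}
# Wrappers cast the bucket to half precision around another hook (e.g. PowerSGD).
WRAPPERS = {
    "fp16": default_hooks.fp16_compress_wrapper,
    "bf16": default_hooks.bf16_compress_wrapper,
}


def register_ddp_comm_hook(model, hook_name, wrapper_name=None, state_options=None):
    """Hypothetical helper: resolve, optionally wrap, and register a DDP comm hook."""
    hook = HOOKS[hook_name]
    if hook_name in ("power_sgd", "batched_power_sgd"):
        # PowerSGD keeps per-step state (error feedback, warm start, iteration count).
        state = powerSGD_hook.PowerSGDState(process_group=None, **(state_options or {}))
    else:
        # The default hooks take the process group as state; None means the default group.
        state = None
    if wrapper_name is not None:
        hook = WRAPPERS[wrapper_name](hook)
    model.register_comm_hook(state=state, hook=hook)
    return state


# Usage on a DDP-wrapped model (requires an initialized process group):
# register_ddp_comm_hook(ddp_model, "power_sgd", wrapper_name="fp16",
#                        state_options={"matrix_approximation_rank": 2})
```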

Fixes # (issue)

N/A


HuggingFaceDocBuilderDev commented 3 months ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

yhna940 commented 3 months ago

Thanks! Overall this looks great to me, a few requests though:

  1. Can we add an example in examples/by_feature specifically showcasing this usage?
  2. Can we add some documentation on it to the official docs? Ideally a Usage Guide if nothing else (though a Concept Guide too would be great!)

Thank you for the review and suggestions :)

In response to your feedback, I have made the following updates:

  1. Added an example specifically showcasing the usage of DDP communication hooks in examples/by_feature/ddp_comm_hook.py.
  2. Created a detailed usage guide for DDP communication hooks, which is now available in the official documentation under docs/source/usage_guides/ddp_comm_hook.md.

These additions aim to provide clear guidance on how to use DDP communication hooks with the 🤗 Accelerate library, improving the usability and performance of distributed training.

Please let me know if there are any further adjustments or additions required.

yhna940 commented 3 months ago

Thank you for the feedback and the suggestion @SunMarc. I have added more details about comm_wrapper and comm_state_option in the documentation. As you mentioned, comm_state_option is currently only applied to PowerSGD in this PR.

PyTorch also provides other stateful hooks, such as post_localSGD_hook. These hooks require specific optimizers and offer more advanced control over gradient communication. For example, post_localSGD_hook is closely tied to PostLocalSGDOptimizer and needs additional setup:

```python
import os

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.nn as nn
from torch.distributed.optim import PostLocalSGDOptimizer
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)

# Assumption: launched with torchrun on GPU nodes, one process per GPU.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
torch.cuda.set_device(local_rank)

# Placeholder model and loss; substitute your own.
module = nn.Linear(10, 10).to(local_rank)
loss_fn = nn.MSELoss()

model = nn.parallel.DistributedDataParallel(
    module, device_ids=[local_rank], output_device=local_rank
)

# Register a post-localSGD communication hook.
state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
model.register_comm_hook(state, post_localSGD_hook)

# Create a post-localSGD optimizer that wraps a local optimizer.
# Note that `warmup_steps` used in `PostLocalSGDOptimizer` must be the same as
# `start_localSGD_iter` used in `PostLocalSGDState`.
local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
opt = PostLocalSGDOptimizer(
    optim=local_optim,
    averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100),
)

# In the first 100 steps, DDP runs global gradient averaging at every step.
# After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
# and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
for step in range(200):
    # Placeholder batch; replace with real data.
    inputs = torch.randn(20, 10, device=local_rank)
    labels = torch.randn(20, 10, device=local_rank)
    opt.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    opt.step()

dist.destroy_process_group()
```

These advanced hooks were not included in this PR because the PyTorch official documentation primarily highlights PowerSGD, FP16, and BF16 hooks. You can find more information about additional hooks in the PyTorch DDP Communication Hooks documentation and the PyTorch GitHub repository.

yhna940 commented 3 months ago

Thanks for the quick review @SunMarc, I've made the requested changes :)

yhna940 commented 3 months ago

Thank you for your detailed review @stevhliu :) I've updated the guide as you suggested: I used the hfoption tag and rearranged some sentences. I'll fix anything else that's wrong. Thanks 🤗

muellerzr commented 3 months ago

Nicely done @yhna940!