NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Avoid unnecessary param write in distributed Adam kernel #1795

Closed · timmoon10 closed this 7 months ago

timmoon10 commented 7 months ago

The distributed Adam kernel is implemented to write simultaneously to the master param buffer (usually FP32) and to the param all-gather buffer (usually BF16). This saves the cost of an extra cast before the param all-gather. However, in some cases the distributed optimizer uses the same buffer for the master params and the param all-gather, e.g. when the all-gather dtype matches the master param dtype or when doing FP8 all-gathers.
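For context, here is a rough single-tensor sketch in PyTorch of the dual-output idea. This is not the actual fused CUDA kernel; the names are hypothetical and bias correction is omitted for brevity.

```python
import torch

def adam_step_dual_output(master_param, grad, exp_avg, exp_avg_sq, param_out,
                          lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Toy stand-in for the fused kernel's dual-output behavior: update the
    # FP32 master param and write the result into the (possibly
    # lower-precision) all-gather buffer in the same pass, instead of
    # casting master_param separately afterwards.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = exp_avg_sq.sqrt().add_(eps)
    master_param.addcdiv_(exp_avg, denom, value=-lr)
    param_out.copy_(master_param)  # second output, e.g. a BF16 all-gather buffer
```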

This PR adds a check to avoid writing to the same buffer twice. I'm not sure how much this issue is already mitigated by caching, so performance evaluations are in progress.
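A minimal sketch of what such a check could look like on the Python side, assuming hypothetical names and a plain data_ptr/dtype comparison (the actual PR may implement the check differently, e.g. in the kernel wrapper):

```python
import torch

def param_write_is_redundant(master_param: torch.Tensor,
                             param_out: torch.Tensor) -> bool:
    # The second write can be skipped when the all-gather buffer aliases the
    # master param buffer: same storage location and same dtype mean the Adam
    # update already produced the bytes the all-gather will read.
    return (param_out.data_ptr() == master_param.data_ptr()
            and param_out.dtype == master_param.dtype)

# Usage in the sketch above: only perform the extra copy when it is needed.
#   if not param_write_is_redundant(master_param, param_out):
#       param_out.copy_(master_param)
```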

timmoon10 commented 7 months ago

The performance impact is small (<5%) when running on an H100, so I think we can close this for now.