The distributed Adam kernel is implemented to simultaneously output to the master param buffer (usually FP32) and to the param all-gather buffer (usually BF16). This saves the cost of doing an extra cast before the param all-gather. However, in some cases the distributed optimizer will use the same buffer for the master params and the param all-gather, e.g. when the all-gather dtype matches the master params or when doing FP8 all-gathers.
This PR adds a check that skips the second write when the master param buffer and the param all-gather buffer are the same buffer. I'm not sure how much of this redundant write is already mitigated by caching, so performance evaluations are in progress.
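The sketch below illustrates the idea in NumPy. It is a minimal sketch under stated assumptions, not the actual CUDA kernel. The function `adam_step_` and its arguments are hypothetical. `float16` stands in for the low-precision all-gather dtype because NumPy has no BF16. Aliasing is detected with `np.shares_memory`. The update writes the master params in place. It casts into the all-gather buffer only when that buffer is a distinct allocation.

```python
import numpy as np

def adam_step_(master, grad, exp_avg, exp_avg_sq, step, out_allgather=None,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical fused Adam step (illustrative only, not the real kernel).

    Updates `master` in place. If `out_allgather` is given and does NOT
    share memory with `master`, the updated params are also cast into it.
    Returns True if the all-gather buffer was written separately.
    """
    # Standard Adam moment updates, done in place on the optimizer state.
    exp_avg *= beta1
    exp_avg += (1 - beta1) * grad
    exp_avg_sq *= beta2
    exp_avg_sq += (1 - beta2) * grad * grad

    # Bias-corrected update applied to the master params.
    bias_c1 = 1 - beta1 ** step
    bias_c2 = 1 - beta2 ** step
    update = (exp_avg / bias_c1) / (np.sqrt(exp_avg_sq / bias_c2) + eps)
    master -= lr * update

    # Skip the second write when both outputs alias the same memory:
    # the updated values are already in the all-gather buffer.
    if out_allgather is not None and not np.shares_memory(out_allgather, master):
        out_allgather[...] = master.astype(out_allgather.dtype)
        return True
    return False
```

With separate buffers, the step writes both outputs. When the caller passes the master buffer as the all-gather buffer, it writes only once. In a PyTorch setting, a comparable check might compare `Tensor.data_ptr()` values. That comparison is an assumption about how such a check could look, not a description of this PR's implementation.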