Currently we skip an all-reduce if it is for gradient accumulation and has been rewritten (call these grad-acc all-reduces). However, after that, such an all-reduce can be merged with an all-reduce that is not for grad-acc. Skipping the merged all-reduce produces incorrect outputs, because the non-grad-acc part is skipped along with it. We should identify grad-acc all-reduces and only allow them to merge with other grad-acc all-reduces.
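The intended merge rule can be sketched roughly as follows. This is a minimal illustration in Python, not the actual pass: `AllReduce`, `is_grad_acc`, `can_merge`, `merge_adjacent`, and `should_skip` are hypothetical names, and merging is simplified to grouping adjacent ops.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class AllReduce:
    """Hypothetical stand-in for an all-reduce instruction."""
    name: str
    is_grad_acc: bool  # assumed marker set when the op was rewritten for grad-acc


def can_merge(a: AllReduce, b: AllReduce) -> bool:
    # Only merge ops of the same kind, so a merged op is never "half" grad-acc.
    return a.is_grad_acc == b.is_grad_acc


def merge_adjacent(ops: List[AllReduce]) -> List[List[AllReduce]]:
    """Group consecutive all-reduces that are allowed to merge."""
    groups: List[List[AllReduce]] = []
    for op in ops:
        if groups and can_merge(groups[-1][-1], op):
            groups[-1].append(op)
        else:
            groups.append([op])
    return groups


def should_skip(group: List[AllReduce]) -> bool:
    # A merged group may be skipped only if every member is grad-acc.
    return all(op.is_grad_acc for op in group)


ops = [
    AllReduce("g1", True),
    AllReduce("n1", False),
    AllReduce("g2", True),
    AllReduce("g3", True),
]
groups = merge_adjacent(ops)
```

With this rule every group is homogeneous, so skipping a group never drops a non-grad-acc all-reduce.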
A reproducer is: