Main reason: 在使用 torch dpp 多卡的时候,BatchNorm 层默认值使用自己卡上的数据计算均值和方差然后进行归一化,导致不同卡上计算得到的均值和方差不同,可能会使训练变得不稳定。
Checklist 检查下面各项是否完成
Please feel free to remove inapplicable items for your PR.
[x] The PR title starts with [$CATEGORY] (例如[bugfix]修复bug,[new]添加新功能,[test]修改测试,[rm]删除旧代码)
[x] Changes are complete (i.e. I finished coding on this PR) 修改完成才提PR
[x] All changes have test coverage 修改的部分顺利通过测试。对于fastnlp/fastnlp/的修改,测试代码必须提供在fastnlp/test/。
[x] Code is well-documented 注释写好,API文档会从注释中抽取
[x] To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change 修改导致例子或tutorial有变化,请找核心开发人员
Description:torch ddp多卡的情况下保证 Batch Norm 使用所有卡上的数据共同计算均值和方差。
Main reason: 在使用 torch dpp 多卡的时候,BatchNorm 层默认值使用自己卡上的数据计算均值和方差然后进行归一化,导致不同卡上计算得到的均值和方差不同,可能会使训练变得不稳定。
Checklist 检查下面各项是否完成
Please feel free to remove inapplicable items for your PR.
Changes: 逐项描述修改的内容