Closed Panlichen closed 1 year ago
我用的是vit_base_patch16_224,进行纯tensor并行的分布式训练,发现在4卡的情况下 ,计算图里一共包含了49个AllReduce,其中1个是聚合loss的,另外48个应该就是按照megatron论文里的描述,每个transformer层前后向共需要4个AllReduce,所以是来自12个transformer层;到了8卡的时候,计算图里的集合通信增加了AllGather和ReduceScatter,共有85个集合通信,按照计算图上的拓扑序,是3个阶段:(24个AllGather+1个AllReduce);(24个AllReduce);([AllGather-ReduceScatter-AllReduce] × 12),请问这样的变化是怎么发生的,从4卡到8卡,模型的划分方式发生了什么样的变化,导致引入了新的通信算子?
使用的配置文件中cfg.num_heads=12,不可以被8整除,同时没有触发非均等划分。 改为cfg.num_heads=16,这个问题就解决了。
不过非均等划分没有启用不知道算不算bug。
来自 liyipeng :真正的原因找到了,非均衡切割还是在工作的,但是,这里有个reshape op,对于reshape op由于数据连续性的问题只支持组头被整除的切分,而8不整除12,所以这里的sbp不支持S2
我用的是vit_base_patch16_224,进行纯tensor并行的分布式训练,发现在4卡的情况下 ,计算图里一共包含了49个AllReduce,其中1个是聚合loss的,另外48个应该就是按照megatron论文里的描述,每个transformer层前后向共需要4个AllReduce,所以是来自12个transformer层;到了8卡的时候,计算图里的集合通信增加了AllGather和ReduceScatter,共有85个集合通信,按照计算图上的拓扑序,是3个阶段:(24个AllGather+1个AllReduce);(24个AllReduce);([AllGather-ReduceScatter-AllReduce] × 12),请问这样的变化是怎么发生的,从4卡到8卡,模型的划分方式发生了什么样的变化,导致引入了新的通信算子?