import torch
def allreduce(data: list):
# 获取需通信的进程总数
world_size = len(data)
# 将所有张量移动到相同的设备上
device = data[0].device
for i in range(1, world_size):
data[i] = data[i].to(device)
# 归约阶段
for i in range(world_size - 1):
if i % 2 == 0:
src = i + 1
dest = i
else:
src = i
dest = i + 1
torch.cuda.synchronize(device) # 同步设备上的计算
data[dest] += data[src]
# 广播阶段
for i in range(1, world_size):
src = 0
dest = i
torch.cuda.synchronize(device) # 同步设备上的计算
data[dest] = data[src].to(data[dest].device)
# 测试
data = [torch.ones((1, 2), device=d2l.try_gpu(i)) * (i + 1) for i in range(2)]
print('allreduce之前:\n', data[0], '\n', data[1])
allreduce(data)
print('allreduce之后:\n', data[0], '\n', data[1])
练习12.5.3
实现一个更高效的 allreduce 函数用于在不同的 GPU 上聚合不同的参数?为什么这样的效率更高?
解答代码中给出的函数只在例子中的data,即包含两项数据时,返回正确结果 如果data有更多项,如
输出的结果仍然只相当于前两项的求和与广播
个人认为函数中归约阶段的代码不完善,没有正确完成累加data中所有数据项的任务