Oneflow-Inc / libai

LiBai(李白): A Toolbox for Large-Scale Distributed Parallel Training
https://libai.readthedocs.io
Apache License 2.0
391 stars 55 forks source link

Refine mt5 #410

Closed xiezipeng-ML closed 2 years ago

xiezipeng-ML commented 2 years ago

https://github.com/Oneflow-Inc/libai/issues/406#issuecomment-1292151939

上文中说的 SelfAttention 中的未知 SendRecv 是必要的,它在代码这里,megatron 这里没有的原因是算法不一样,megatron 里面的 t5 没有这个 position_bias。

position_bias 这里 position_bias (S(0), B) 要与 attention_scores (S(0), S(1)) 做计算,需要做一个 (S(0), B) -> (S(0), S(1)),目前 2d SBP 里面是用 SendRecv 实现的,但可以用 SameDim0AllScatter 来实现(没有通信开销)。

但上述 (S(0), B) -> (S(0), S(1)) 的转换不用每一层 layer 都做,因为 position_bias 是在 layer 0 通过 compute_bias 计算出来的,后面的所有 layer 使用的都是 layer 0 的 position_bias,所以该转换只需要做一次。而 position_bias 在与 attention_scores 相加之前,需要先与 attention_mask (S(0), B) 相加(见这里),加完之后 position_bias sbp 也变为了 (S(0), B)。

我们只需要将 position_bias = position_bias.to_global(placement=attention_scores.placement) 这行代码移动到前面的 if 作用域之内,position_bias = position_bias + (1 - attention_mask) * -1000 之后,即可使 (S(0), B) -> (S(0), S(1)) 的转换只做1次。

根据wenxiao的这个refine一下mt5的compute_bias中的to_global位置

strint commented 2 years ago

需要做一个测评,看下性能变化的指标,再合并

xiezipeng-ML commented 2 years ago
# num_layers=6
修改后:
显存:2459MiB
[10/28 06:00:17 libai]: >>> done with building model. Building time: 0.902 seconds
[10/28 06:03:40 lb.utils.events]:  eta: 0:51:33  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.505  time: 0.1323 s/iter  data_time: 0.0114 s/iter total_throughput: 60.49 samples/s lr: 8.49e-05

修改前:
显存:2459MiB
done with building model. Building time: 1.366 seconds
[10/28 05:59:29 lb.utils.events]:  eta: 0:51:30  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.513  time: 0.1322 s/iter  data_time: 0.0118 s/iter total_throughput: 60.51 samples/s lr: 8.49e-05  

# num_layers=12
修改后:
3587MiB
[10/28 06:06:15 libai]: >>> done with building model. Building time: 1.312 seconds
[10/28 06:12:40 lb.utils.events]:  eta: 1:37:56  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.514  time: 0.2518 s/iter  data_time: 0.0117 s/iter total_throughput: 31.77 samples/s lr: 8.49e-05 

修改前:
3587MiB
[10/28 06:13:43 libai]: >>> done with building model. Building time: 1.555 seconds
[10/28 06:20:14 lb.utils.events]:  eta: 1:39:01  iteration: 849/24000  consumed_samples: 6800  total_loss: 3.514  time: 0.2579 s/iter  data_time: 0.0176 s/iter total_throughput: 31.03 samples/s lr: 8.49e-05 

@strint 可能这个改变的影响有限