syncdoth opened this issue 1 year ago
The benchmark is run under float32, and the `reduce` mode is only valid for float32.
If D=128, I think the advantage only appears when S > 10_000 $\approx D^2$. However, I think it is possible to use the `reduce` mode directly in float32 if $\gamma$ is close to 1.
In float16 and bfloat16, there is no speed benefit with the current code (2023.10.16). Moreover, in float16 and bfloat16 the three ways are not even numerically consistent with each other with the current code (2023.10.16).
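One plausible contributing factor for the half-precision inconsistency (my own quick check, not code from this repo): the decay $\gamma$ and its cumulative powers are poorly represented once cast to fp16/bf16. A minimal illustration:

```python
import torch

# Hypothetical illustration (not code from this repo): how the decay gamma and
# its powers behave after casting to each dtype. For a decay this close to 1,
# bfloat16 rounds gamma to exactly 1.0 (no decay at all), and float16 is at the
# edge of its resolution, so paths that accumulate gamma**n diverge across dtypes.
S = 4096
gamma = 1 - 2 ** -12          # one of the per-head decays, if I recall the RetNet schedule correctly
n = torch.arange(S, dtype=torch.float64)

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    g = torch.tensor(gamma, dtype=dtype)   # gamma rounded to the target dtype
    powers = g.double() ** n               # exact powers of the *rounded* gamma
    print(dtype, "gamma ->", g.item(), "gamma**(S-1) ->", powers[-1].item())
```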
So far, I recommend just using huggingface accelerate with bfloat16 and the `origin` way to train the model.
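For what it's worth, a minimal sketch of that setup with Accelerate's built-in bf16 mixed precision (the model and dataloader names below are placeholders, not from this repo):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")   # bf16 autocast handled by Accelerate

model = YourRetNetModel()            # placeholder: your RetNet module
train_loader = ...                   # placeholder: your DataLoader
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

for batch in train_loader:
    loss = model(**batch).loss       # assuming an HF-style output with a .loss field
    accelerator.backward(loss)       # instead of loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```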
Under my testing, bfloat16 directly gives about a 1.5x speedup on a 3090 (and should be larger on an A100).

Benchmark setting: H=16, B=1.
Float32
S | D1 | D2 | e1 | e2 | fast | reduce | origin | speed_up_fast | speed_up_reduce |
---|---|---|---|---|---|---|---|---|---|
1024 | 32 | 32 | 2.62E-10 | 3.54E-10 | 0.00214 | 0.00108 | 0.00115 | 0.539 | 1.073 |
1024 | 64 | 64 | 1.93E-10 | 2.55E-10 | 0.00643 | 0.00145 | 0.00103 | 0.161 | 0.713 |
1024 | 128 | 128 | 1.47E-10 | 1.89E-10 | 0.02507 | 0.00154 | 0.00114 | 0.046 | 0.741 |
2048 | 32 | 32 | 1.68E-10 | 2.42E-10 | 0.00386 | 0.00214 | 0.00331 | 0.858 | 1.547 |
2048 | 64 | 64 | 1.21E-10 | 1.72E-10 | 0.01462 | 0.00291 | 0.00353 | 0.242 | 1.215 |
2048 | 128 | 128 | 9.03E-11 | 1.26E-10 | 0.05813 | 0.00822 | 0.00397 | 0.068 | 0.483 |
4096 | 32 | 32 | 1.08E-10 | | 0.00940 | 0.00451 | 0.01235 | 1.314 | 2.741 |
4096 | 64 | 64 | 7.83E-11 | | 0.03623 | 0.00689 | 0.01325 | 0.366 | 1.923 |
4096 | 128 | 128 | 5.70E-11 | | 0.14264 | 0.01686 | 0.01498 | 0.105 | 0.888 |
BFloat16
S | D1 | D2 | e1 | e2 | fast | reduce | origin | speed_up_fast | speed_up_reduce |
---|---|---|---|---|---|---|---|---|---|
1024 | 32 | 32 | 1.21E-04 | 1.55E-05 | 0.00220 | 0.00110 | 0.00044 | 0.199 | 0.399 |
1024 | 64 | 64 | 8.30E-05 | 1.09E-05 | 0.00706 | 0.00137 | 0.00046 | 0.065 | 0.333 |
1024 | 128 | 128 | 5.79E-05 | 7.69E-06 | 0.02724 | 0.00309 | 0.00045 | 0.016 | 0.145 |
2048 | 32 | 32 | 1.02E-04 | 1.02E-05 | 0.00387 | 0.00195 | 0.00154 | 0.398 | 0.791 |
2048 | 64 | 64 | 7.39E-05 | 7.15E-06 | 0.01476 | 0.00273 | 0.00161 | 0.109 | 0.591 |
2048 | 128 | 128 | 5.25E-05 | 5.10E-06 | 0.05786 | 0.00620 | 0.00155 | 0.027 | 0.251 |
4096 | 32 | 32 | 8.87E-05 | | 0.00862 | 0.00398 | 0.00566 | 0.657 | 1.421 |
4096 | 64 | 64 | 6.29E-05 | | 0.03377 | 0.00551 | 0.00575 | 0.17 | 1.044 |
4096 | 128 | 128 | 4.48E-05 | | 0.13161 | 0.01343 | 0.00595 | 0.045 | 0.443 |
It is also weird that the `discounted_torchsum` method is slower than the `reduce` method.
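If anyone wants to reproduce these numbers on their own shapes, a timing harness along these lines should work (this is not necessarily how the tables above were produced; `fast_retention`, `reduce_retention`, and `origin_retention` are placeholders for the corresponding implementations in this repo):

```python
import torch

def bench(fn, *args, warmup=3, iters=20):
    """Average CUDA wall-clock time of fn(*args) in seconds."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000 / iters   # elapsed_time is in ms

B, H, S, D = 1, 16, 2048, 64
q = torch.randn(B, H, S, D, device="cuda")
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

# Placeholders for the three implementations being compared in the tables above.
for name, fn in [("fast", fast_retention), ("reduce", reduce_retention), ("origin", origin_retention)]:
    print(name, bench(fn, q, k, v))
```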
Thanks for the great work! I am just wondering what would happen if we had higher D. This is because the RetNet configs that you can obtain from torchscale (and also mine) typically have D=128 (such as RetNet-13b, which has embed_dim=5120, num_head=40), and I would like to train with either S=2048 or S=4096. How much speedup / memory reduction could we expect?