veya2ztn / fast_retention

Speed up Parallel Retention by about 2x

Question about larger D #2

Open syncdoth opened 1 year ago

syncdoth commented 1 year ago

Thanks for the great work!

I am just wondering what would happen if we had a higher D. The RetNet configs that you can obtain from torchscale (and also mine) typically have D=128 (such as RetNet-13b, which has embed_dim=5120, num_head=40), and I would like to train with either S=2048 or S=4096. How much speed-up / memory reduction could we expect?
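For concreteness, the per-head dimension I mean is just the head split of the embedding (the arithmetic for RetNet-13b):

$$D = \frac{\text{embed\_dim}}{\text{num\_head}} = \frac{5120}{40} = 128.$$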

veya2ztn commented 1 year ago

The benchmark is run under float32, and the reduce mode is only valid for float32. If D=128, I think the advantage only appears when S > 10,000 $\approx D^2$.
However, I think it is possible to use the reduce mode directly in float32 if $\gamma$ is close to $1$.
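Spelling that threshold out for the D=128 case asked about above (just the arithmetic implied by the statement, not a new measurement):

$$D = 128 \;\Rightarrow\; D^2 = 16{,}384, \qquad S \in \{2048,\ 4096\} \ll D^2,$$

so both target sequence lengths sit well below the regime where the reduce mode is expected to win in float32.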

In float16 and bfloat16, there is no benefit for the current code (2023.10.16). In float16 and bfloat16, the three ways are not even consistent with each other for the current code (2023.10.16).

So far, I recommend just using Hugging Face accelerate with bfloat16 and the origin way to train the model.
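A minimal sketch of that setup with Hugging Face accelerate and bf16 mixed precision (the model, optimizer, and data here are toy placeholders, not the retention model from this repo):

```python
# Toy training loop: accelerate + bfloat16 mixed precision.
# Everything model/data-related below is a stand-in, not repository code.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator(mixed_precision="bf16")

model = torch.nn.Linear(128, 128)                        # placeholder for the RetNet model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = TensorDataset(torch.randn(64, 128), torch.randn(64, 128))
dataloader = DataLoader(dataset, batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for x, y in dataloader:
    optimizer.zero_grad()
    with accelerator.autocast():                         # forward runs under bfloat16 autocast
        loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)                           # replaces loss.backward()
    optimizer.step()
```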

veya2ztn commented 1 year ago

Under my testing, bfloat16 can directly give a 1.5x speed-up on a 3090 (and the gain should be larger on an A100).
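If you want to check that on your own GPU, here is a rough, hedged timing sketch; `retention_forward` below is just a parallel-retention-shaped toy stand-in, not a function exported by this repository, and the shapes follow the H=16, B=1 setting used in the tables below:

```python
# Rough float32 vs bfloat16-autocast timing comparison on CUDA (toy forward only).
import contextlib
import time
import torch

def retention_forward(q, k, v, decay):
    # toy parallel-retention-shaped forward: (Q K^T * decay) V
    scores = (q @ k.transpose(-1, -2)) * decay        # (B, H, S, S)
    return scores @ v                                 # (B, H, S, D)

def bench(dtype, B=1, H=16, S=4096, D=128, iters=20):
    q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
    decay = torch.rand(1, H, S, S, device="cuda")
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        ctx = (torch.autocast("cuda", dtype=torch.bfloat16)
               if dtype is torch.bfloat16 else contextlib.nullcontext())
        with ctx, torch.no_grad():
            retention_forward(q, k, v, decay)
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

t32, tbf16 = bench(torch.float32), bench(torch.bfloat16)
print(f"float32 {t32:.4f}s  bfloat16 {tbf16:.4f}s  speed-up {t32 / tbf16:.2f}x")
```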

veya2ztn commented 1 year ago

H=16, B=1

Float32

| S | D1 | D2 | e1 | e2 | fast | reduce | origin | speed_up_fast | speed_up_reduce |
|---|----|----|----|----|------|--------|--------|---------------|-----------------|
| 1024 | 32 | 32 | 2.62E-10 | 3.54E-10 | 0.00214 | 0.00108 | 0.00115 | 0.539 | 1.073 |
| 1024 | 64 | 64 | 1.93E-10 | 2.55E-10 | 0.00643 | 0.00145 | 0.00103 | 0.161 | 0.713 |
| 1024 | 128 | 128 | 1.47E-10 | 1.89E-10 | 0.02507 | 0.00154 | 0.00114 | 0.046 | 0.741 |
| 2048 | 32 | 32 | 1.68E-10 | 2.42E-10 | 0.00386 | 0.00214 | 0.00331 | 0.858 | 1.547 |
| 2048 | 64 | 64 | 1.21E-10 | 1.72E-10 | 0.01462 | 0.00291 | 0.00353 | 0.242 | 1.215 |
| 2048 | 128 | 128 | 9.03E-11 | 1.26E-10 | 0.05813 | 0.00822 | 0.00397 | 0.068 | 0.483 |
| 4096 | 32 | 32 | 1.08E-10 | | 0.00940 | 0.00451 | 0.01235 | 1.314 | 2.741 |
| 4096 | 64 | 64 | 7.83E-11 | | 0.03623 | 0.00689 | 0.01325 | 0.366 | 1.923 |
| 4096 | 128 | 128 | 5.70E-11 | | 0.14264 | 0.01686 | 0.01498 | 0.105 | 0.888 |

BFloat16

| S | D1 | D2 | e1 | e2 | fast | reduce | origin | speed_up_fast | speed_up_reduce |
|---|----|----|----|----|------|--------|--------|---------------|-----------------|
| 1024 | 32 | 32 | 1.21E-04 | 1.55E-05 | 0.00220 | 0.00110 | 0.00044 | 0.199 | 0.399 |
| 1024 | 64 | 64 | 8.30E-05 | 1.09E-05 | 0.00706 | 0.00137 | 0.00046 | 0.065 | 0.333 |
| 1024 | 128 | 128 | 5.79E-05 | 7.69E-06 | 0.02724 | 0.00309 | 0.00045 | 0.016 | 0.145 |
| 2048 | 32 | 32 | 1.02E-04 | 1.02E-05 | 0.00387 | 0.00195 | 0.00154 | 0.398 | 0.791 |
| 2048 | 64 | 64 | 7.39E-05 | 7.15E-06 | 0.01476 | 0.00273 | 0.00161 | 0.109 | 0.591 |
| 2048 | 128 | 128 | 5.25E-05 | 5.10E-06 | 0.05786 | 0.00620 | 0.00155 | 0.027 | 0.251 |
| 4096 | 32 | 32 | 8.87E-05 | | 0.00862 | 0.00398 | 0.00566 | 0.657 | 1.421 |
| 4096 | 64 | 64 | 6.29E-05 | | 0.03377 | 0.00551 | 0.00575 | 0.170 | 1.044 |
| 4096 | 128 | 128 | 4.48E-05 | | 0.13161 | 0.01343 | 0.00595 | 0.045 | 0.443 |
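For anyone who wants to reproduce rows in this format, here is a sketch of a harness under my reading of the columns: e1/e2 as errors of fast/reduce against origin, and speed_up_* as the origin time divided by the corresponding method time (which matches the numbers above). The three callables are placeholders for the actual implementations, and the toy at the bottom is only there to make the snippet runnable:

```python
# Hedged benchmark-harness sketch; the fast / reduce / origin callables are
# placeholders for the real implementations in this repo.
import time
import torch

def time_fn(fn, *args, iters=20):
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        out = fn(*args)
    torch.cuda.synchronize()
    return out, (time.time() - t0) / iters

def benchmark_row(origin_fn, fast_fn, reduce_fn, q, k, v, decay):
    ref, t_origin = time_fn(origin_fn, q, k, v, decay)
    out_fast, t_fast = time_fn(fast_fn, q, k, v, decay)
    out_reduce, t_reduce = time_fn(reduce_fn, q, k, v, decay)
    return dict(
        e1=(out_fast - ref).abs().max().item(),      # error of fast vs origin
        e2=(out_reduce - ref).abs().max().item(),    # error of reduce vs origin
        fast=t_fast, reduce=t_reduce, origin=t_origin,
        speed_up_fast=t_origin / t_fast,
        speed_up_reduce=t_origin / t_reduce,
    )

if __name__ == "__main__":
    B, H, S, D = 1, 16, 1024, 32
    q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
    decay = torch.rand(1, H, S, S, device="cuda")
    toy = lambda q, k, v, decay: ((q @ k.transpose(-1, -2)) * decay) @ v
    print(benchmark_row(toy, toy, toy, q, k, v, decay))  # toy stand-ins only
```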

It is also weird that the discounted_torchsum method is slower than the reduce method.