ROCm / flash-attention

Fast and memory-efficient exact attention

Optimization based on profiling for forward #10

Closed · guangzlu closed this 1 year ago

guangzlu commented 1 year ago

Added non-padding and non-dropout conditions to improve performance. The forward pass can now exceed 70 TFLOPS in the non-padding case (see the table below).
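
For reference, the TFLOPS column below follows the standard forward attention FLOP count: two (seqlen × seqlen) matmuls (QKᵀ and PV) give 4·B·H·S²·D FLOPs, halved under a causal mask. A minimal check in Python, which reproduces the table's values:

```python
def attention_tflops(batch, nheads, seqlen, headdim, fwd_ms, causal):
    # Two (S x S) matmuls per head, QK^T and PV: 4 * B * H * S^2 * D FLOPs.
    # A causal mask skips roughly half of each matmul.
    flops = 4 * batch * nheads * seqlen ** 2 * headdim
    if causal:
        flops /= 2
    return flops / (fwd_ms * 1e-3) / 1e12  # ms -> s, FLOPs -> TFLOPS

# First two rows of the table:
print(attention_tflops(32, 16, 512, 128, 1.28, causal=True))   # 26.84
print(attention_tflops(32, 16, 512, 128, 1.26, causal=False))  # 54.54
```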

| dtype | batch size | embedding size | nheads | head dim | seqlen | causal | dropout | MI250 fwd (ms), previous | MI250 fwd (ms), now | TFLOPS, previous | TFLOPS, now | TFLOPS improvement, (now/previous) - 1 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| torch.float16 | 32 | 2048 | 16 | 128 | 512 | True | 0 | 1.48 | 1.28 | 23.21603944 | 26.8435456 | 0.15625 |
| torch.float16 | 32 | 2048 | 16 | 128 | 512 | False | 0 | 1.5 | 1.26 | 45.81298449 | 54.53926725 | 0.19047619 |
| torch.float16 | 16 | 2048 | 16 | 128 | 1024 | True | 0 | 2.47 | 2.21 | 27.8216505 | 31.09478585 | 0.117647059 |
| torch.float16 | 16 | 2048 | 16 | 128 | 1024 | False | 0 | 2.52 | 2.09 | 54.53926725 | 65.76026482 | 0.205741627 |
| torch.float16 | 8 | 2048 | 16 | 128 | 2048 | True | 0 | 3.6 | 3.06 | 38.17748708 | 44.91469068 | 0.176470588 |
| torch.float16 | 8 | 2048 | 16 | 128 | 2048 | False | 0 | 4.63 | 3.81 | 59.36887839 | 72.14643227 | 0.215223097 |
| torch.float16 | 4 | 2048 | 16 | 128 | 4096 | True | 0 | 5.93 | 5.06 | 46.35377857 | 54.32369702 | 0.171936759 |
| torch.float16 | 4 | 2048 | 16 | 128 | 4096 | False | 0 | 8.97 | 7.29 | 61.28827357 | 75.41232015 | 0.230452675 |
| torch.float16 | 2 | 2048 | 16 | 128 | 8192 | True | 0 | 10.65 | 9.13 | 51.62026421 | 60.21421839 | 0.166484118 |
| torch.float16 | 2 | 2048 | 16 | 128 | 8192 | False | 0 | 17.8 | 14.31 | 61.77031617 | 76.83519411 | 0.243885395 |
| torch.float16 | 1 | 2048 | 16 | 128 | 16384 | True | 0 | 20.33 | 17.36 | 54.08320845 | 63.33592326 | 0.171082949 |
| torch.float16 | 1 | 2048 | 16 | 128 | 16384 | False | 0 | 35.54 | 28.53 | 61.8745992 | 77.07757643 | 0.245706274 |
| torch.float16 | 32 | 2048 | 32 | 64 | 512 | True | 0 | 1.47 | 1.33 | 23.37397168 | 25.83438975 | 0.105263158 |
| torch.float16 | 32 | 2048 | 32 | 64 | 512 | False | 0 | 1.47 | 1.24 | 46.74794336 | 55.41893285 | 0.185483871 |
| torch.float16 | 16 | 2048 | 32 | 64 | 1024 | True | 0 | 2.45 | 2.21 | 28.04876601 | 31.09478585 | 0.108597285 |
| torch.float16 | 16 | 2048 | 32 | 64 | 1024 | False | 0 | 2.53 | 2.00 | 54.32369702 | 68.71947674 | 0.265 |
| torch.float16 | 8 | 2048 | 32 | 64 | 2048 | True | 0 | 3.59 | 3.16 | 38.28383105 | 43.49333971 | 0.136075949 |
| torch.float16 | 8 | 2048 | 32 | 64 | 2048 | False | 0 | 4.75 | 3.76 | 57.86903304 | 73.10582631 | 0.263297872 |
| torch.float16 | 4 | 2048 | 32 | 64 | 4096 | True | 0 | 5.96 | 5.23 | 46.12045419 | 52.5579172 | 0.13957935 |
| torch.float16 | 4 | 2048 | 32 | 64 | 4096 | False | 0 | 9.26 | 7.26 | 59.36887839 | 75.72394131 | 0.275482094 |
| torch.float16 | 2 | 2048 | 32 | 64 | 8192 | True | 0 | 10.69 | 9.38 | 51.42711075 | 58.60936182 | 0.139658849 |
| torch.float16 | 2 | 2048 | 32 | 64 | 8192 | False | 0 | 18.29 | 14.32 | 60.11545258 | 76.78153825 | 0.277234637 |
| torch.float16 | 1 | 2048 | 32 | 64 | 16384 | True | 0 | 20.27 | 17.82 | 54.24329688 | 61.70098921 | 0.137485971 |
| torch.float16 | 1 | 2048 | 32 | 64 | 16384 | False | 0 | 36.44 | 28.58 | 60.34641206 | 76.94273112 | 0.275017495 |
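
A rough sketch of a timing harness for these configurations, assuming a v2-style `flash_attn_func(q, k, v, dropout_p=..., causal=...)` entry point; this fork may instead expose `flash_attn_unpadded_func` with packed sequences, so adapt the call to your install:

```python
import torch
from flash_attn import flash_attn_func  # assumption: v2-style API

def bench_fwd(batch, nheads, seqlen, headdim, causal, iters=30, warmup=5):
    # Random fp16 inputs in (batch, seqlen, nheads, headdim) layout.
    q, k, v = (torch.randn(batch, seqlen, nheads, headdim,
                           device="cuda", dtype=torch.float16)
               for _ in range(3))
    for _ in range(warmup):
        flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)
    start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    start.record()
    for _ in range(iters):
        flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)
    end.record()
    torch.cuda.synchronize()
    fwd_ms = start.elapsed_time(end) / iters
    # Same FLOP count as the table: 4*B*H*S^2*D, halved for causal.
    flops = 4 * batch * nheads * seqlen ** 2 * headdim / (2 if causal else 1)
    print(f"seqlen={seqlen:6d} causal={causal!s:5}: "
          f"{fwd_ms:7.2f} ms, {flops / (fwd_ms * 1e-3) / 1e12:5.1f} TFLOPS")

# The table's sweep keeps batch * seqlen = 16384 tokens while seqlen doubles:
for batch, seqlen in [(32, 512), (16, 1024), (8, 2048),
                      (4, 4096), (2, 8192), (1, 16384)]:
    for causal in (True, False):
        bench_fwd(batch, nheads=16, seqlen=seqlen, headdim=128, causal=causal)
```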