ggerganov / ggml

Tensor library for machine learning
MIT License

reduce for kqmax_new_j is unnecessary #1032

Open mahorozte opened 2 days ago

mahorozte commented 2 days ago

With this patch, performance improves by about 1%-2% on an A800, measured with:

test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf

For this evaluation only, I used a workaround so that nb=1, 2, 3 and 7 all run flash_attn_vec_ext_f16 (the A800 supports WMMA, so a WMMA-based kernel would otherwise be selected for some of these cases).

Baseline (without the patch):
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.51 us/run - 4.19 MFLOP/run - 364.41 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 14.38 us/run - 8.39 MFLOP/run - 583.51 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.18 us/run - 12.58 MFLOP/run - 594.02 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 30654 runs - 34.17 us/run - 29.36 MFLOP/run - 859.30 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 19.03 us/run - 8.39 MFLOP/run - 440.74 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 41727 runs - 25.04 us/run - 16.78 MFLOP/run - 670.07 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27818 runs - 36.84 us/run - 25.17 MFLOP/run - 683.06 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 67.79 us/run - 58.72 MFLOP/run - 866.21 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 12.68 us/run - 4.19 MFLOP/run - 330.79 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.16 us/run - 8.39 MFLOP/run - 519.10 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 22.11 us/run - 12.58 MFLOP/run - 569.02 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 37.08 us/run - 29.36 MFLOP/run - 791.89 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.47 us/run - 8.39 MFLOP/run - 509.47 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.93 us/run - 16.78 MFLOP/run - 765.03 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 35.48 us/run - 25.17 MFLOP/run - 709.22 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 17030 runs - 60.34 us/run - 58.72 MFLOP/run - 973.20 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.23 us/run - 4.19 MFLOP/run - 373.53 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 13.85 us/run - 8.39 MFLOP/run - 605.60 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 20.40 us/run - 12.58 MFLOP/run - 616.89 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 40.79 us/run - 29.36 MFLOP/run - 719.80 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.64 us/run - 8.39 MFLOP/run - 450.01 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 23.73 us/run - 16.78 MFLOP/run - 707.06 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 34.49 us/run - 25.17 MFLOP/run - 729.75 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 66.55 us/run - 58.72 MFLOP/run - 882.32 GFLOPS

With the patch applied:
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 90112 runs - 11.14 us/run - 4.19 MFLOP/run - 376.67 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 73728 runs - 14.02 us/run - 8.39 MFLOP/run - 598.41 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 20.66 us/run - 12.58 MFLOP/run - 609.01 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 30654 runs - 33.68 us/run - 29.36 MFLOP/run - 871.69 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.82 us/run - 8.39 MFLOP/run - 445.67 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 41727 runs - 24.57 us/run - 16.78 MFLOP/run - 682.88 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27818 runs - 36.41 us/run - 25.17 MFLOP/run - 691.19 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 66.20 us/run - 58.72 MFLOP/run - 887.06 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 12.70 us/run - 4.19 MFLOP/run - 330.27 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 15.85 us/run - 8.39 MFLOP/run - 529.23 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.69 us/run - 12.58 MFLOP/run - 580.05 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 36.73 us/run - 29.36 MFLOP/run - 799.45 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 65536 runs - 16.19 us/run - 8.39 MFLOP/run - 518.10 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 21.47 us/run - 16.78 MFLOP/run - 781.60 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 34.47 us/run - 25.17 MFLOP/run - 730.00 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=1,max_bias=8.000000,logit_softcap=0.000000,type_KV=f16): 17030 runs - 59.55 us/run - 58.72 MFLOP/run - 986.15 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 98304 runs - 10.93 us/run - 4.19 MFLOP/run - 383.64 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 81920 runs - 13.52 us/run - 8.39 MFLOP/run - 620.46 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 55636 runs - 19.85 us/run - 12.58 MFLOP/run - 633.98 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 27248 runs - 40.10 us/run - 29.36 MFLOP/run - 732.12 GFLOPS

FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 57344 runs - 18.36 us/run - 8.39 MFLOP/run - 456.96 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 47688 runs - 23.40 us/run - 16.78 MFLOP/run - 716.87 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=3,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 31792 runs - 33.66 us/run - 25.17 MFLOP/run - 747.69 GFLOPS
FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=7,mask=0,max_bias=0.000000,logit_softcap=0.000000,type_KV=f16): 15327 runs - 65.55 us/run - 58.72 MFLOP/run - 895.82 GFLOPS