ROCm / composable_kernel

Composable Kernel: Performance Portable Programming Model for Machine Learning Tensor Operators
https://rocm.docs.amd.com/projects/composable_kernel/en/latest/

Hacking ck_tile fmha Dropout facility #1344

Closed · qianfengz closed this 3 months ago

qianfengz commented 3 months ago

This PR makes two changes to the ck_tile BlockDropout facility and the forward kernel:

  1. Add NullBlockDropout, which reduces vgpr/sgpr usage when kHasDropout is false (the pure inference path); see the first sketch after this list.
  2. Change BlockDropout::Run() to reduce the conditional checking on is_store_randval; see the second sketch after this list.
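
For illustration, here is a minimal C++ sketch of how a no-op dropout policy can strip dropout state out of the pure inference instantiation. The type names, members, and the Run() signature below are simplified assumptions modeled on the description above, not the actual ck_tile code.

```cpp
#include <cstdint>
#include <type_traits>

// Stateful dropout policy: its members (RNG state, scale factors) occupy
// vgprs/sgprs for the lifetime of the kernel.
struct BlockDropout
{
    uint64_t rng_state; // per-thread RNG state (illustrative)
    float    p_undrop;  // keep probability, 1 - p
    float    rp_undrop; // rescale factor, 1 / (1 - p)

    template <typename AccTile>
    __device__ void Run(AccTile& acc, bool is_store_randval, uint8_t* randval_out)
    {
        for(int i = 0; i < AccTile::kSize; ++i)
        {
            // simple LCG as a stand-in for the real per-thread RNG
            rng_state = rng_state * 6364136223846793005ULL + 1442695040888963407ULL;
            const float r = (rng_state >> 40) * (1.0f / 16777216.0f); // uniform [0, 1)
            acc[i] = (r < p_undrop) ? acc[i] * rp_undrop : 0.0f;
            if(is_store_randval) // checked per element; see the second sketch
                randval_out[i] = static_cast<uint8_t>(r * 255.0f);
        }
    }
};

// Stateless no-op policy: no members and an empty Run(), so the
// kHasDropout == false instantiation carries no dropout state at all.
struct NullBlockDropout
{
    template <typename AccTile>
    __device__ void Run(AccTile&, bool, uint8_t*) {}
};

// The forward kernel selects the dropout policy at compile time.
template <bool kHasDropout>
using BlockDropoutPolicy =
    typename std::conditional<kHasDropout, BlockDropout, NullBlockDropout>::type;
```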
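The PR does not show the diff for the second change, but one common way to reduce this kind of conditional checking is to lift is_store_randval out of the element loop into a compile-time parameter and branch once at dispatch time. The sketch below is an assumption along those lines, reusing the illustrative names from the previous sketch.

```cpp
#include <cstdint>

// Loop body specialized at compile time: when kStoreRandval is false the
// store path is dead code, so no per-element branch remains.
template <bool kStoreRandval, typename AccTile>
__device__ void run_dropout(AccTile& acc, uint64_t& rng_state,
                            float p_undrop, float rp_undrop, uint8_t* randval_out)
{
    for(int i = 0; i < AccTile::kSize; ++i)
    {
        rng_state = rng_state * 6364136223846793005ULL + 1442695040888963407ULL;
        const float r = (rng_state >> 40) * (1.0f / 16777216.0f); // uniform [0, 1)
        acc[i] = (r < p_undrop) ? acc[i] * rp_undrop : 0.0f;
        if constexpr(kStoreRandval) // resolved at compile time, not per element
            randval_out[i] = static_cast<uint8_t>(r * 255.0f);
    }
}

// A single runtime branch picks one of the two specializations.
template <typename AccTile>
__device__ void dispatch_dropout(AccTile& acc, uint64_t& rng_state, float p_undrop,
                                 float rp_undrop, uint8_t* randval_out,
                                 bool is_store_randval)
{
    if(is_store_randval)
        run_dropout<true>(acc, rng_state, p_undrop, rp_undrop, randval_out);
    else
        run_dropout<false>(acc, rng_state, p_undrop, rp_undrop, randval_out);
}
```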
# Before the change (the second-to-last column shows the execution time)
"f16 384-197-1-88, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,65,12.697265625
"f16 384-197-1-64, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,36,9.234375
"f16 1024-197-1-64, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,89,24.625
"f16 32-197-16-80, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,94,15.390625
"f16 16-197-16-88, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,45,8.46484375
"f16 150-256-16-64, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,261,75.0
"f16 1-16384-16-40, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,3858,20.0
"f16 2-2048-8-128, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,167,8.0
"f16 16-128-16-256, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,46,16.0
"f16 16-1024-16-32, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,244,16.0
# After the change (the second-to-last column shows the execution time)
"f16 384-197-1-88, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,53,12.697265625
"f16 384-197-1-64, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,35,9.234375
"f16 1024-197-1-64, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,80,24.625
"f16 32-197-16-80, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,88,15.390625
"f16 16-197-16-88, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,38,8.46484375
"f16 150-256-16-64, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,252,75.0
"f16 1-16384-16-40, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,3746,20.0
"f16 2-2048-8-128, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,143,8.0
"f16 16-128-16-256, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,36,16.0
"f16 16-1024-16-32, p=0.3, BiasT=NoneType",attention (attn_bias=<class 'NoneType'>),1,ckF,,232,16.0