When `flash_attn_func` is called, I see that `query_layer` has a different batch size from `key_layer` and `value_layer`. Since the cat operation pads `key_layer` and `value_layer` with one extra row/column, after the view operation `key_layer` and `value_layer` end up with a larger batch size than `query_layer`. Is this intended, and does `flash_attn_func` support this kind of usage, or is it a bug?
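To make the shape arithmetic concrete, here is a minimal sketch of what I mean. The sizes, the 5-D layout, and the helper `batch_after_view` are all made up for illustration and are not taken from the actual model code; I am also assuming the view folds the row dimension into the batch dimension.

```python
from math import prod


def batch_after_view(shape, tail):
    """Leading dim that a reshape to (-1, *tail) would infer for `shape`.

    Hypothetical helper: it only mimics the size arithmetic of a view,
    it does not touch any real tensor.
    """
    total, tail_size = prod(shape), prod(tail)
    if total % tail_size:
        raise ValueError("shape is not divisible by the requested trailing dims")
    return total // tail_size


# Made-up sizes, not the real model's dimensions.
b, rows, cols, heads, dim = 2, 4, 4, 8, 64

q_shape = (b, rows, cols, heads, dim)
# Assumption: the cat adds one extra row and one extra column to key/value.
kv_shape = (b, rows + 1, cols + 1, heads, dim)

# Assumption: the view folds the row dimension into the batch dimension.
q_batch = batch_after_view(q_shape, (cols, heads, dim))
kv_batch = batch_after_view(kv_shape, (cols + 1, heads, dim))

print(q_batch, kv_batch)  # 8 10
```

Under these assumptions, query ends up with batch size 8 while key and value end up with 10, which is the kind of mismatch I am seeing.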
cc @yzy-thu