I benchmarked with (batch_size, head_size, dim, seq_len) = (4, 8, 64, 16384), and the runtime was 3.24 ms. However, with (batch_size, head_size, dim, seq_len) = (4, 8, 128, 16384), the runtime was only 2.03 ms.
The dim=128 case presumably does more work yet runs faster, which suggests the dim=64 path still has room for optimization.
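To put a rough number on the gap, here is a back-of-the-envelope check. It assumes that the total work grows linearly with dim while batch_size, head_size and seq_len stay fixed. That scaling is an assumption, not something measured. Under it, the dim=64 run reaches only about 31% of the dim=128 run's throughput.

```python
# Measured runtimes from the benchmark above (milliseconds), keyed by dim.
# batch_size=4, head_size=8, seq_len=16384 are held fixed in both runs.
runtime_ms = {64: 3.24, 128: 2.03}


def relative_efficiency(runtimes):
    """Return work-per-millisecond for each dim, normalized to the best case.

    ASSUMPTION: total work grows linearly with dim, so work-per-ms is
    proportional to dim / runtime. This is a rough model, not a measurement.
    """
    throughput = {d: d / t for d, t in runtimes.items()}
    best = max(throughput.values())
    return {d: tp / best for d, tp in throughput.items()}


eff = relative_efficiency(runtime_ms)
for d in sorted(eff):
    print(f"dim={d}: {eff[d]:.2f} of best throughput")
# Prints:
# dim=64: 0.31 of best throughput
# dim=128: 1.00 of best throughput
```

If the real cost does not scale linearly with dim (for example, if the kernel is dominated by memory traffic that is independent of dim), the ratio will differ. Even then, a smaller problem taking longer is a strong sign that the dim=64 configuration is not well tuned.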