I really love this project and the accompanying blogpost, so thanks! I've reimplemented some of the inference techniques to speed up an implementation of Whisper that I am using. I had a few questions about the attention kernels, as they have been giving me some performance-related issues.
By adding print statements, I can see that during the attention calculation (not including prefilling) the shapes are essentially:
k - (bs, n_heads, max_seq_len, d_head)
v - (bs, n_heads, max_seq_len, d_head)
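For context, this is roughly the kind of static cache update that produces those shapes. It's a minimal sketch, and the names (`StaticKVCache`, `input_pos`) are my own placeholders rather than anything from this repo:

```python
import torch

class StaticKVCache(torch.nn.Module):
    """Minimal static KV cache sketch: buffers are pre-allocated at max_seq_len,
    so the k/v tensors handed to attention always carry the full max_seq_len dim."""

    def __init__(self, bs, n_heads, max_seq_len, d_head, dtype=torch.float16):
        super().__init__()
        shape = (bs, n_heads, max_seq_len, d_head)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_new, v_new):
        # input_pos: (num_new_tokens,) positions being written this step
        # k_new/v_new: (bs, n_heads, num_new_tokens, d_head)
        self.k_cache[:, :, input_pos] = k_new
        self.v_cache[:, :, input_pos] = v_new
        # The full pre-allocated buffers are returned, hence the shapes above.
        return self.k_cache, self.v_cache
```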
I understand that max_seq_len is there because of the static KV cache implementation, and my assumption was that, thanks to the attention mask, F.scaled_dot_product_attention combined with torch.compile would be able to tell that it doesn't need to compute attention over the entire max_seq_len. In my case, however, I've found that the max_seq_len value has a big effect on inference speed, which suggests to me that full attention (over the entire max_seq_len context) is being performed on every iteration. The slowdown is vastly reduced when using the following context manager, as is done in generate.py:
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
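To make the setup concrete, the decode-time attention call in my case boils down to something like this. It's a runnable sketch with made-up sizes, not the actual Whisper code:

```python
import torch
import torch.nn.functional as F

bs, n_heads, d_head, max_seq_len, cur_pos = 1, 8, 64, 4096, 100  # illustrative sizes

# Single-token decode: the query is one token, but k/v come back from the
# static cache at the full max_seq_len.
q = torch.randn(bs, n_heads, 1, d_head, device="cuda", dtype=torch.float16)
k = torch.randn(bs, n_heads, max_seq_len, d_head, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

# Boolean mask that only exposes the positions filled so far.
mask = (torch.arange(max_seq_len, device="cuda") < cur_pos).view(1, 1, 1, max_seq_len)

# Force the math backend, as generate.py does.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_mem_efficient=False, enable_math=True
):
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(y.shape)  # torch.Size([1, 8, 1, 64])
```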
If I exclude this, I am seeing a 3x reduction in tok/s. Does this make sense? Or is it a sign that I've implemented something wrong? Even with this context manager, I see a significant (50%+) increase in tok/s when I reduce the context length from 4096 to 2048 or 1024.
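If it's useful, this is roughly how I've been measuring the effect of max_seq_len on the isolated attention call. Again a sketch with invented sizes (and my own `time_sdpa` helper name), not numbers from the real model:

```python
import torch
import torch.nn.functional as F

def time_sdpa(max_seq_len, bs=1, n_heads=8, d_head=64, cur_pos=100, iters=200):
    """Average time (ms) of one single-token SDPA call against a full-length cache."""
    q = torch.randn(bs, n_heads, 1, d_head, device="cuda", dtype=torch.float16)
    k = torch.randn(bs, n_heads, max_seq_len, d_head, device="cuda", dtype=torch.float16)
    v = torch.randn_like(k)
    mask = (torch.arange(max_seq_len, device="cuda") < cur_pos).view(1, 1, 1, max_seq_len)

    for _ in range(10):  # warmup
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_mem_efficient=False, enable_math=True
):
    for length in (1024, 2048, 4096):
        print(length, time_sdpa(length))
```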
Thanks in advance. If it helps, here is my CUDA-graph-friendly Whisper implementation using a static KV cache: