Could you explain the detailed implementation of clustered_sparse_dot_product in the _topk_attention function of the ImprovedClusteredAttention class? I am a little confused about how QK is computed in the code snippet below:
class ImprovedClusteredAttention(Module):
    ......
    def _topk_attention(self, Q, K, V,
                        clusters, counts,
                        topk, topk_values,
                        A_bottomk, softmax_temp,
                        query_lengths):
        N, H, L, E = Q.shape
        _, _, S, _ = K.shape
        _, _, C, k = topk.shape
        # We need to pass the output tensor to initialize to 0
        QK = clustered_sparse_dot_product(
            Q, K, topk,
            clusters, counts,
            query_lengths._lengths.int()
        )
        ......
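As far as I understand it, clustered_sparse_dot_product is a fused CUDA kernel, so here is my attempt at a plain NumPy sketch of the math it seems to compute: each query is mapped to a cluster, each cluster has k selected key indices (topk), and QK holds the dot products of every query against only its cluster's k keys. The shapes follow the snippet above (Q: N,H,L,E; K: N,H,S,E; topk: N,H,C,k); the clusters argument here is assumed to be a per-query cluster-id tensor of shape (N,H,L), and counts/query_lengths (used by the real kernel for grouping and padding) are omitted. This is only an illustration of the idea, not the library's actual implementation.

```python
import numpy as np

def clustered_sparse_dot_product_ref(Q, K, topk, clusters):
    """Reference sketch: sparse QK scores restricted to each
    query's cluster top-k keys. Returns shape (N, H, L, k)."""
    N, H, L, E = Q.shape
    _, _, C, k = topk.shape
    QK = np.zeros((N, H, L, k), dtype=Q.dtype)
    for n in range(N):
        for h in range(H):
            for i in range(L):
                c = clusters[n, h, i]           # cluster id of query i
                keys = K[n, h, topk[n, h, c]]   # gather the (k, E) keys
                QK[n, h, i] = keys @ Q[n, h, i] # k dot products
    return QK
```

So, if my reading is right, QK[n, h, i, j] is the dot product of query i with the j-th top-k key of query i's cluster, and the CUDA kernel just performs this gather-plus-matmul in one pass (with counts letting it process queries cluster by cluster instead of looping per query).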