Problem
Our current format annotations do not efficiently handle branches in the axis dependency tree. For example, the irregular batched-GEMM operator has the following dependency tree:
"""
  B
/ | \
I J K
X: (B, I, K)
Y: (B, K, J)
Z: (B, I, J)
"""
B = T.dense_fixed(batch_size, "int32")
I = T.dense_variable(B, (m, nnz_I), indptr_I, "int32")
J = T.dense_variable(B, (n, nnz_J), indptr_J, "int32")
K = T.dense_variable(B, (k, nnz_K), indptr_K, "int32")
X = T.match_sparse_buffer(x, (B, I, K), "float32")
Y = T.match_sparse_buffer(y, (B, K, J), "float32")
Z = T.match_sparse_buffer(z, (B, I, J), "float32")
# "SSSR": b, i, j are spatial axes and k is the reduction axis
with T.iter([B, I, J, K], "SSSR", "irregular-batched-gemm") as [b, i, j, k]:
    with T.init():
        Z[b, i, j] = T.float32(0)
    Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]
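To pin down what the iteration computes, here is a plain-Python reference sketch. It is not SparseTIR code, and the function name `irregular_batched_gemm` is hypothetical. For every batch b it computes Z[b] = X[b] @ Y[b], where the extents m_b, k_b and n_b of I, K and J differ from batch to batch. Nested lists stand in for the ragged buffers, and the sketch assumes every k_b >= 1.

```python
from typing import List

Matrix = List[List[float]]


def irregular_batched_gemm(xs: List[Matrix], ys: List[Matrix]) -> List[Matrix]:
    """Reference semantics: Z[b][i][j] = sum over k of X[b][i][k] * Y[b][k][j].

    Batch b has its own extents m_b = len(xs[b]), k_b = len(ys[b]) and
    n_b = len(ys[b][0]); this sketch assumes every k_b >= 1.
    """
    zs: List[Matrix] = []
    for x, y in zip(xs, ys):                 # spatial axis B
        m, k, n = len(x), len(y), len(y[0])  # per-batch extents of I, K, J
        z = [[0.0] * n for _ in range(m)]    # plays the role of T.init()
        for i in range(m):                   # spatial axis I
            for j in range(n):               # spatial axis J
                for kk in range(k):          # reduction axis K
                    z[i][j] += x[i][kk] * y[kk][j]
        zs.append(z)
    return zs


# Batch 0 is (1x2) @ (2x1); batch 1 is (3x1) @ (1x2).
xs = [[[1.0, 2.0]], [[1.0], [2.0], [3.0]]]
ys = [[[3.0], [4.0]], [[5.0, 6.0]]]
print(irregular_batched_gemm(xs, ys))  # [[[11.0]], [[5.0, 6.0], [10.0, 12.0], [15.0, 18.0]]]
```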
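Efficient indexing of X, Y, and Z requires auxiliary buffers such as indptr_IK, indptr_KJ, and indptr_IJ. Each buffer spans two sibling axes under B (for example, I and K for X), so neither axis's own indptr alone locates where batch b's two-dimensional block starts in the flattened buffer. Currently, however, SparseTIR provides no interface for declaring such buffers.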
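The sketch below shows how such auxiliary buffers could be derived from the per-batch extents with a prefix sum. It is plain Python, and the helper names `axis_indptr` and `block_indptr` are hypothetical. It also assumes that each batch's 2-D block is stored densely in row-major order, which is the layout Alternative 2 below relies on.

```python
from itertools import accumulate
from typing import List


def axis_indptr(extents: List[int]) -> List[int]:
    """indptr of a single variable axis, e.g. indptr_I from the per-batch m_b."""
    return [0] + list(accumulate(extents))


def block_indptr(rows: List[int], cols: List[int]) -> List[int]:
    """indptr of a per-batch 2-D block, e.g. indptr_IK from m_b * k_b.

    Entry b is the offset of batch b's block in the flat value buffer,
    assuming each block is stored densely in row-major order.
    """
    return [0] + list(accumulate(r * c for r, c in zip(rows, cols)))


m, k, n = [1, 3], [2, 1], [1, 2]  # per-batch extents of I, K, J
print(axis_indptr(m))        # indptr_I  -> [0, 1, 4]
print(block_indptr(m, k))    # indptr_IK -> [0, 2, 5]
print(block_indptr(k, n))    # indptr_KJ -> [0, 2, 4]
print(block_indptr(m, n))    # indptr_IJ -> [0, 1, 7]
```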
Proposals
Let's take X: (B, I, K) as an example:
Alternative 1: Create a new axis IK whose parent axis is I, and use it to replace K
IK = T.dense_variable(I, ...)
# before lowering
X[b, i, k]
# after lowering
x[indptr_ik[indptr_i[b] + i] + k]
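The lowered expression can be checked with a small plain-Python model of Alternative 1. The helper names are hypothetical, and the layout (values of x stored contiguously, row after row) is an assumption of this sketch. Because IK follows I, indptr_ik gets one entry per flattened row of I, so it has nnz_I + 1 entries.

```python
from itertools import accumulate
from typing import List, Tuple


def build_alt1_indptrs(m: List[int], k: List[int]) -> Tuple[List[int], List[int]]:
    """Return (indptr_i, indptr_ik) for Alternative 1.

    indptr_i maps batch b to its first row in the flattened I axis;
    indptr_ik maps a flattened row to its first value in x.
    """
    indptr_i = [0] + list(accumulate(m))
    # IK follows I, so it gets one extent per flattened row; here every
    # row of batch b has the same extent k[b].
    row_extents = [kb for mb, kb in zip(m, k) for _ in range(mb)]
    indptr_ik = [0] + list(accumulate(row_extents))
    return indptr_i, indptr_ik


def x_offset_alt1(indptr_i: List[int], indptr_ik: List[int], b: int, i: int, kk: int) -> int:
    # Mirrors the lowered access x[indptr_ik[indptr_i[b] + i] + k].
    return indptr_ik[indptr_i[b] + i] + kk


m, k = [1, 3], [2, 1]  # per-batch extents of I and K
indptr_i, indptr_ik = build_alt1_indptrs(m, k)
offsets = [x_offset_alt1(indptr_i, indptr_ik, b, i, kk)
           for b in range(len(m)) for i in range(m[b]) for kk in range(k[b])]
print(indptr_i, indptr_ik)  # [0, 1, 4] [0, 2, 3, 4, 5]
print(offsets)              # [0, 1, 2, 3, 4]
```

Enumerating every (b, i, k) visits each slot of x exactly once and in order, so the lowered index addresses a dense, gap-free buffer.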
Alternative 2: Insert a bridge axis IK that flattens I and K
IK = T.flatten([I, K], ...)
# before lowering
X[b, i, k]
# after lowering
x[indptr_ik[b] + i * (indptr_k[b + 1] - indptr_k[b]) + k]
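A matching plain-Python model of Alternative 2 follows. As before, the helper names are hypothetical, and the sketch assumes each batch's I x K block is stored densely in row-major order. Here indptr_ik is indexed by the batch only, and the row stride comes from K's own indptr.

```python
from itertools import accumulate
from typing import List, Tuple


def build_alt2_indptrs(m: List[int], k: List[int]) -> Tuple[List[int], List[int]]:
    """Return (indptr_k, indptr_ik) for Alternative 2.

    indptr_k is K's own per-batch indptr; indptr_ik gives the start of
    batch b's flattened I x K block, so it has one entry per batch (+1).
    """
    indptr_k = [0] + list(accumulate(k))
    indptr_ik = [0] + list(accumulate(mb * kb for mb, kb in zip(m, k)))
    return indptr_k, indptr_ik


def x_offset_alt2(indptr_k: List[int], indptr_ik: List[int], b: int, i: int, kk: int) -> int:
    # Mirrors x[indptr_ik[b] + i * (indptr_k[b + 1] - indptr_k[b]) + k]:
    # a row-major offset inside batch b's dense m_b x k_b block.
    k_b = indptr_k[b + 1] - indptr_k[b]
    return indptr_ik[b] + i * k_b + kk


m, k = [1, 3], [2, 1]  # per-batch extents of I and K
indptr_k, indptr_ik = build_alt2_indptrs(m, k)
offsets = [x_offset_alt2(indptr_k, indptr_ik, b, i, kk)
           for b in range(len(m)) for i in range(m[b]) for kk in range(k[b])]
print(indptr_k, indptr_ik)  # [0, 2, 3] [0, 2, 5]
print(offsets)              # [0, 1, 2, 3, 4]
```

With the same extents as the Alternative 1 model, this produces the same offsets into x. The difference in this sketch is the size of the auxiliary buffer: indptr_ik needs batch_size + 1 entries instead of nnz_I + 1. In exchange, it relies on every row of a batch having the same K extent, which holds for this operator.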