
SparseTIR: Sparse Tensor Compiler for Deep Learning
https://sampl.cs.washington.edu/SparseTIR/
Apache License 2.0

[Discussion] How to efficiently handle branches in the axes dependency tree for sparse buffers. #69

Open yzh119 opened 2 years ago

yzh119 commented 2 years ago

Problem

Our current format annotations do not efficiently handle branches in the axes dependency tree. For example, the irregular batched-GEMM operator has the following dependency tree:

"""
     B
   / | \
  I  J  K

X: (B, I, K)
Y: (B, K, J)
Z: (B, I, J)
"""

B = T.dense_fixed(batch_size, "int32")
I = T.dense_variable(B, (m, nnz_I), indptr_I, "int32")
J = T.dense_variable(B, (n, nnz_J), indptr_J, "int32")
K = T.dense_variable(B, (k, nnz_K), indptr_K, "int32")

X = T.match_sparse_buffer(x, (B, I, K), "float32")
Y = T.match_sparse_buffer(y, (B, K, J), "float32")
Z = T.match_sparse_buffer(z, (B, I, J), "float32")

with T.iter([B, I, J, K], "SSSR", "irregular-batched-gemm") as [b, i, j, k]:
    with T.init():
        Z[b, i, j] = T.float32(0)
    Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]

Efficient indexing of X/Y/Z requires auxiliary buffers such as indptr_IK, indptr_KJ, and indptr_IJ, but SparseTIR currently provides no interface to declare them.
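
For concreteness, here is a minimal NumPy sketch (not SparseTIR code) of what the lowered kernel has to compute. It assumes each batch's matrices are stored densely in row-major order inside the flat x/y/z buffers, with indptr_IK/indptr_KJ/indptr_IJ giving per-batch start offsets:

import numpy as np

def ragged_batched_gemm(x, y, indptr_i, indptr_j, indptr_k,
                        indptr_ik, indptr_kj, indptr_ij):
    z = np.zeros(indptr_ij[-1], dtype=np.float32)
    for b in range(len(indptr_i) - 1):
        I_b = indptr_i[b + 1] - indptr_i[b]  # rows of X/Z in batch b
        J_b = indptr_j[b + 1] - indptr_j[b]  # cols of Y/Z in batch b
        K_b = indptr_k[b + 1] - indptr_k[b]  # reduction extent in batch b
        X_b = x[indptr_ik[b]:indptr_ik[b + 1]].reshape(I_b, K_b)
        Y_b = y[indptr_kj[b]:indptr_kj[b + 1]].reshape(K_b, J_b)
        z[indptr_ij[b]:indptr_ij[b + 1]] = (X_b @ Y_b).reshape(-1)
    return z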

Proposals

Let's take X: (B, I, K) as an example:

Alternative 1: Create a new axis IK that follows I to replace K

IK = T.dense_variable(I, ...)

# before lowering
X[b, i, k]

# after lowering:
x[indptr_ik[indptr_i[b] + i] + k]
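
In plain Python, this lowering is a two-level pointer chase (addr_alt1 is a hypothetical helper, not SparseTIR API). Note that indptr_ik here needs one entry per (b, i) row, i.e. nnz_I + 1 entries, because IK hangs off I:

def addr_alt1(indptr_i, indptr_ik, b, i, k):
    # indptr_i[b] + i is the global row id of (b, i) across all batches;
    # indptr_ik maps that row to the start of its K segment in the flat x.
    row = indptr_i[b] + i
    return indptr_ik[row] + k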

Alternative 2: Insert a bridge axis IK that flattens I and K

IK = T.flatten([I, K], ...)

# before
X[b, i, k]

# after lowering
x[indptr_ik[b] + i * (indptr_k[b + 1] - indptr_k[b]) + k]
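
Under the dense row-major layout assumed above, indptr_ik for this alternative needs only batch_size + 1 entries: a prefix sum of the per-batch products I_b * K_b. A sketch with hypothetical helpers build_indptr_ik / addr_alt2:

import numpy as np

def build_indptr_ik(indptr_i, indptr_k):
    # Prefix sum of I_b * K_b: the start offset of each batch's (I, K) block.
    sizes = np.diff(indptr_i) * np.diff(indptr_k)
    return np.concatenate(([0], np.cumsum(sizes)))

def addr_alt2(indptr_ik, indptr_k, b, i, k):
    K_b = indptr_k[b + 1] - indptr_k[b]  # K extent of batch b
    return indptr_ik[b] + i * K_b + k    # matches the lowered expression above

Both sketches reproduce the lowered expressions above; the open question is how SparseTIR should declare and materialize these auxiliary buffers.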