In the current design of SparseTIR, the value field of every sparse buffer is stored internally as a 1D array, and the sparse buffer lowering pass flattens every sparse buffer to one dimension.
However, this is not necessary: we only need to flatten variable axes, while the dimensions of fixed axes can be kept. For example:
```python
# before flattening
I = T.dense_fixed(m, "int32")
J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
K = T.dense_fixed(k, "int32")
A = T.match_sparse_buffer(a, (I, J, K), "float32")
# after flattening (previous behavior)
A = T.match_buffer(a, (nnz * k,), "float32")
# after flattening (new behavior)
A = T.match_buffer(a, (nnz, k), "float32")
```
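To make the difference between the two layouts in the example above concrete, here is a small pure-Python sketch (not SparseTIR code) that addresses the same element of A in both. It assumes a CSR-style indptr array for J and treats jj as the position of a non-zero within row i; the helper names offset_1d and index_2d are hypothetical.

```python
from typing import List, Tuple


def offset_1d(indptr: List[int], i: int, jj: int, kk: int, k: int) -> int:
    """Previous layout: (I, J) and K are all collapsed into one (nnz * k,) dimension."""
    return (indptr[i] + jj) * k + kk


def index_2d(indptr: List[int], i: int, jj: int, kk: int) -> Tuple[int, int]:
    """New layout: only the variable chain (I, J) is collapsed into nnz; K keeps its dimension."""
    return (indptr[i] + jj, kk)


# Toy instance: m = 3 rows holding 2, 1 and 2 non-zeros, so nnz = 5; k = 2.
indptr = [0, 2, 3, 5]
nnz, k = indptr[-1], 2

flat = [float(x) for x in range(nnz * k)]              # storage of shape (nnz * k,)
table = [flat[r * k:(r + 1) * k] for r in range(nnz)]  # same data viewed as (nnz, k)

# Every logical element (i, jj, kk) maps to the same value under both layouts.
for i in range(len(indptr) - 1):
    for jj in range(indptr[i + 1] - indptr[i]):
        for kk in range(k):
            r, c = index_2d(indptr, i, jj, kk)
            assert flat[offset_1d(indptr, i, jj, kk, k)] == table[r][c]
```

The 2D form keeps K as an ordinary dense dimension, so loops over K can be scheduled exactly as in a dense program.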
We should flatten only chains of variable axes and leave the other axes in their original form. This design helps us reuse the schedules for the "dense" parts of a tensor program when SparseTIR is integrated with Relax, the graph-level IR in the TVM stack.
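The flattening rule can be sketched as a shape computation. Axis, flattened_shape and previous_shape below are hypothetical helpers, not part of SparseTIR. The sketch assumes that each variable axis's ancestors appear immediately before it in the buffer, and that a variable axis's extent is the total number of non-zeros of its chain (nnz).

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

# Axis kinds follow SparseTIR's naming; the Axis record itself is hypothetical.
VARIABLE_KINDS = {"dense_variable", "sparse_variable"}


@dataclass
class Axis:
    name: str
    kind: str    # "dense_fixed", "sparse_fixed", "dense_variable" or "sparse_variable"
    extent: int  # dense_fixed: length; sparse_fixed: non-zeros per row; variable: total nnz
    parent: Optional[str] = None


def flattened_shape(axes: List[Axis]) -> Tuple[int, ...]:
    """New behavior: fold each variable axis together with its ancestor chain into one dimension."""
    parent_of: Dict[str, Optional[str]] = {ax.name: ax.parent for ax in axes}
    dims: List[Tuple[Set[str], int]] = []  # (axes folded into this dimension, extent)
    for ax in axes:
        if ax.kind in VARIABLE_KINDS:
            ancestors: Set[str] = set()
            p = ax.parent
            while p is not None:
                ancestors.add(p)
                p = parent_of[p]
            folded = {ax.name}
            # Assumption: the ancestors occupy the trailing dimensions built so far.
            while dims and dims[-1][0] & ancestors:
                folded |= dims.pop()[0]
            dims.append((folded, ax.extent))
        else:
            # Fixed axes (dense_fixed or sparse_fixed) keep their own dimension.
            dims.append(({ax.name}, ax.extent))
    return tuple(extent for _, extent in dims)


def previous_shape(axes: List[Axis]) -> Tuple[int]:
    """Previous behavior: every sparse buffer is flattened to 1D."""
    return (math.prod(flattened_shape(axes)),)


# First example with m = 4, nnz = 10, k = 3; second example is the 8x4 matrix with 2:4 sparsity.
csr_3d = [Axis("I", "dense_fixed", 4), Axis("J", "sparse_variable", 10, "I"), Axis("K", "dense_fixed", 3)]
sparse_2_4 = [Axis("I", "dense_fixed", 8), Axis("J", "sparse_fixed", 2, "I")]
```

Here flattened_shape(csr_3d) gives (10, 3) where previous_shape(csr_3d) gives (30,), and flattened_shape(sparse_2_4) gives (8, 2).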
In particular, sparse_fixed axes do not need to be flattened because the number of non-zeros along them is fixed. The following example is an 8x4 sparse matrix with 2:4 sparsity:
```python
# before flattening
I = T.dense_fixed(8, "int32")
J = T.sparse_fixed(I, (4, 2), indices, "int32")
A = T.match_sparse_buffer(a, (I, J), "float32")
# after flattening
A = T.match_buffer(a, (8, 2), "float32")
```
After flattening, A becomes a compact 8x2 dense matrix, while the column positions of the non-zeros remain in indices.
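As a rough illustration of what the compact layout holds, the pure-Python sketch below packs an 8x4 matrix with at most two non-zeros per row into an 8x2 values array plus a matching column-index array, which play the roles of the flattened A and of indices in the example above. compress_2_4 and decompress_2_4 are hypothetical helpers, not SparseTIR APIs, and the sketch assumes 2:4 means at most two non-zeros in each row of four columns (rows with fewer are padded with explicit zeros).

```python
from typing import List, Tuple


def compress_2_4(dense: List[List[float]]) -> Tuple[List[List[float]], List[List[int]]]:
    """Pack each 4-wide row with at most 2 non-zeros into 2 values plus 2 column indices."""
    values: List[List[float]] = []
    indices: List[List[int]] = []
    for row in dense:
        if len(row) != 4:
            raise ValueError("expected rows of width 4")
        cols = [j for j, x in enumerate(row) if x != 0.0]
        if len(cols) > 2:
            raise ValueError("row violates 2:4 sparsity")
        # Pad rows with fewer than 2 non-zeros with explicit zero slots so every row has exactly 2.
        for j in range(4):
            if len(cols) == 2:
                break
            if j not in cols:
                cols.append(j)
        cols.sort()
        values.append([row[j] for j in cols])
        indices.append(cols)
    return values, indices


def decompress_2_4(values: List[List[float]], indices: List[List[int]]) -> List[List[float]]:
    """Scatter the compact (rows x 2) values back into a (rows x 4) dense matrix."""
    dense = [[0.0] * 4 for _ in values]
    for i, (vals, cols) in enumerate(zip(values, indices)):
        for v, c in zip(vals, cols):
            dense[i][c] = v
    return dense


# An 8x4 matrix in which every row has exactly two non-zeros.
a_dense = []
for i in range(8):
    row = [0.0] * 4
    row[i % 4] = float(i + 1)
    row[(i + 1) % 4] = float(-(i + 1))
    a_dense.append(row)

values, indices = compress_2_4(a_dense)
assert len(values) == 8 and all(len(v) == 2 for v in values)  # compact 8x2 storage
assert decompress_2_4(values, indices) == a_dense
```

Because the compact values form a regular 8x2 array, dense-style schedules can operate on it directly instead of going through a 1D offset computation.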
The new design also lets us move some schedules (e.g., 2:4 sparse tensor core tensorization) to Stage III IR.