uwsampl / SparseTIR

SparseTIR: Sparse Tensor Compiler for Deep Learning
https://sampl.cs.washington.edu/SparseTIR/
Apache License 2.0

[Tracking Issue] Redesign the internal storage of sparse buffers. #65

Open yzh119 opened 2 years ago

yzh119 commented 2 years ago

Pitch

In the current design of SparseTIR, the internal storage of the value field of a sparse buffer is 1-dimensional: the sparse buffer lowering pass flattens every sparse buffer into a 1-dimensional buffer.

However, this is unnecessary: we only need to flatten variable axes, while fixed axes can keep their dimensions. For example:

# before flattening
I = T.dense_fixed(m, "int32")
J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
K = T.dense_fixed(k, "int32")
A = T.match_sparse_buffer(a, (I, J, K), "float32")

# after flattening (previous behavior)
A = T.match_buffer(a, (nnz * k,), "float32")

# after flattening (new behavior)
A = T.match_buffer(a, (nnz, k), "float32")
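The two layouts hold the same data. The NumPy sketch below checks this on a small, made-up CSR pattern for I/J with a dense trailing axis K; the sizes, `indptr`, `indices`, and the helper functions `offset_1d`/`index_2d` are illustrative assumptions, not SparseTIR code. The j-th non-zero of row i at position kk lives at offset `(indptr[i] + j) * k + kk` in the old 1-D buffer and at `[indptr[i] + j, kk]` in the new 2-D buffer, so the new buffer is just a reshape of the old one with K kept as an ordinary dense dimension.

```python
import numpy as np

# Illustrative CSR structure for the I/J axes: m = 3 rows, n = 4 columns, nnz = 5.
m, n, k = 3, 4, 2
indptr = np.array([0, 2, 3, 5], dtype=np.int32)
indices = np.array([0, 3, 1, 0, 2], dtype=np.int32)
nnz = int(indptr[-1])

# New layout: one row per non-zero of (I, J); the fixed axis K stays dense.
a_2d = np.arange(nnz * k, dtype=np.float32).reshape(nnz, k)
# Previous layout: the same values flattened to (nnz * k,).
a_1d = a_2d.reshape(nnz * k)

def offset_1d(i, j, kk):
    """Offset of (row i, j-th non-zero of row i, kk) in the old 1-D buffer."""
    return (indptr[i] + j) * k + kk

def index_2d(i, j, kk):
    """Index of the same element in the new (nnz, k) buffer."""
    return (indptr[i] + j, kk)

# Both layouts address the same element for every valid (i, j, kk).
for i in range(m):
    for j in range(indptr[i + 1] - indptr[i]):
        for kk in range(k):
            assert a_1d[offset_1d(i, j, kk)] == a_2d[index_2d(i, j, kk)]

# Densify to the logical (m, n, k) tensor: dense[i, indices[p], :] == a_2d[p, :].
dense = np.zeros((m, n, k), dtype=np.float32)
for i in range(m):
    for p in range(indptr[i], indptr[i + 1]):
        dense[i, indices[p], :] = a_2d[p, :]
```

In the 2-D layout each `a_2d[p, :]` is a contiguous dense row over K, which is the part a dense schedule could operate on unchanged.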

We should flatten only chains of "variable" axes and leave the other axes in their original form. This design lets us reuse the schedules for the "dense" parts of a tensor program when SparseTIR is integrated with Relax, the graph-level IR in the TVM stack.
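The rule can be made concrete with a small Python sketch that computes the storage shape of a sparse buffer's value field from its axes. The `Axis` class and `flattened_shape` function are hypothetical illustrations, not SparseTIR APIs, and the sketch assumes that each variable axis's parent is the axis immediately before it. Fixed axes keep their stored extent (the length for dense_fixed, the per-row non-zero count for sparse_fixed), while each variable axis folds its parent's dimension into a single non-zero dimension, so a chain of variable axes collapses to one dimension. Passing `flatten_all=True` reproduces the previous 1-D behavior.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class Axis:
    """Hypothetical axis descriptor mirroring the four SparseTIR axis kinds."""
    kind: str                  # "dense_fixed", "sparse_fixed", "dense_variable", "sparse_variable"
    length: int                # logical length of the axis
    nnz: Optional[int] = None  # sparse_fixed: non-zeros per row; variable axes: total non-zeros

def flattened_shape(axes: List[Axis], flatten_all: bool = False) -> Tuple[int, ...]:
    """Storage shape of the value field (sketch; assumes parent = previous axis)."""
    shape: List[int] = []
    for axis in axes:
        if axis.kind == "dense_fixed":
            shape.append(axis.length)
        elif axis.kind == "sparse_fixed":
            shape.append(axis.nnz)   # fixed number of stored elements per row
        elif axis.kind in ("dense_variable", "sparse_variable"):
            shape.pop()              # the parent's dimension is absorbed ...
            shape.append(axis.nnz)   # ... into one flat non-zero dimension
        else:
            raise ValueError(f"unknown axis kind: {axis.kind}")
    if flatten_all:
        return (math.prod(shape),)   # previous behavior: a single 1-D buffer
    return tuple(shape)

# CSR rows with a dense trailing axis (sizes are arbitrary).
csr = [Axis("dense_fixed", 4), Axis("sparse_variable", 5, nnz=7), Axis("dense_fixed", 3)]
# An 8x4 matrix with 2:4 sparsity, like the sparse_fixed case below.
two_four = [Axis("dense_fixed", 8), Axis("sparse_fixed", 4, nnz=2)]

print(flattened_shape(csr), flattened_shape(csr, flatten_all=True))  # (7, 3) (21,)
print(flattened_shape(two_four))                                     # (8, 2)
```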

In particular, sparse_fixed axes do not need to be flattened, because every row stores the same fixed number of non-zeros. The following example is an 8x4 sparse matrix with 2:4 sparsity:

I = T.dense_fixed(8, "int32")
J = T.sparse_fixed(I, (4, 2), indices, "int32")
A = T.match_sparse_buffer(a, (I, J), "float32")

# after flattening
A = T.match_buffer(a, (8, 2), "float32")

After flattening, A becomes a compact 8x2 dense matrix.
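A short NumPy sketch shows the compaction. The random matrix and the magnitude-based choice of which 2 of every 4 entries to keep are illustrative assumptions; `indices` plays the role of the `indices` array of the sparse_fixed axis J, and `values` is the compact 8x2 buffer, from which the logical 8x4 matrix can be rebuilt.

```python
import numpy as np

rng = np.random.default_rng(0)

# Build an 8x4 matrix with 2:4 sparsity: each row keeps 2 of its 4 entries.
dense = rng.standard_normal((8, 4)).astype(np.float32)
keep = np.argsort(np.abs(dense), axis=1)[:, 2:]   # columns of the 2 largest magnitudes per row
indices = np.sort(keep, axis=1).astype(np.int32)  # column indices for J, shape (8, 2)
mask = np.zeros_like(dense, dtype=bool)
np.put_along_axis(mask, indices, True, axis=1)
dense = np.where(mask, dense, 0.0).astype(np.float32)

# Compact storage of A after flattening: an 8x2 dense matrix.
values = np.take_along_axis(dense, indices, axis=1)

# Rebuild the logical 8x4 matrix from (indices, values).
restored = np.zeros((8, 4), dtype=np.float32)
np.put_along_axis(restored, indices, values, axis=1)
```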

The new design also enables us to move some of the schedules (e.g. 2:4 sparse tensor core tensorization) to stage III IR.

Checklist