uwsampl / SparseTIR

SparseTIR: Sparse Tensor Compiler for Deep Learning
https://sampl.cs.washington.edu/SparseTIR/
Apache License 2.0

[Tracking Issue] Inherit block iter_var from parent blocks #71

Open yzh119 opened 1 year ago

yzh119 commented 1 year ago

Pitch

The Sparse Iteration Lowering pass does not make inner blocks inherit iter_vars from their outer blocks. This works, but it does not follow the TensorIR design principle that blocks should be pluggable: if a block does not inherit the outer block's iter_vars, it relies on its context and cannot be moved with `compute_at` to an arbitrary block.

# Lowered nested block structure in the previous design. We can use `vi` inside the
# "inner" block, but we cannot `compute_at` "inner" outside "outer", because the
# information about `vi` is lost.
for i in range(20):
    with T.block("outer"):
        vi = T.axis.spatial(20, i)  # T.axis.spatial(extent, binding)
        for j in range(indptr[i + 1] - indptr[i]):
            with T.block("inner"):
                vj = T.axis.spatial(10, j)
                ...

# Inherit `vi` from the "outer" block inside the "inner" block. We can now `compute_at`
# "inner" outside "outer", and `vi_1` is re-bound to other iter values.
for i in range(20):
    with T.block("outer"):
        vi = T.axis.spatial(20, i)
        for j in range(indptr[i + 1] - indptr[i]):
            with T.block("inner"):
                vi_1 = T.axis.spatial(20, vi)  # re-declares the outer iter_var
                vj = T.axis.spatial(10, j)
                ...

The current design (no inheritance) also leads to bugs such as #60: if every block carried all outer block iter_vars, cache read/write would not need to consider iter_vars in outer blocks.
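To make the "pluggable" argument concrete, here is a minimal sketch in plain Python. It is a toy model, not the SparseTIR or TVM implementation: the `Block` class and the helpers `context_dependencies` and `inherit_iter_vars` are hypothetical names introduced only for illustration. A block is treated as pluggable when its body only uses iter_vars the block declares itself.

```python
from dataclasses import dataclass
from typing import Dict, Set


@dataclass
class Block:
    """Toy stand-in for a TensorIR block (hypothetical, for illustration only)."""
    name: str
    # iter_var name -> the loop var or outer iter_var it is bound to
    iter_vars: Dict[str, str]
    # iter_var names referenced by the block body
    body_uses: Set[str]


def context_dependencies(block: Block) -> Set[str]:
    """Variables the body uses that the block does not declare itself."""
    return block.body_uses - set(block.iter_vars)


def inherit_iter_vars(outer: Block, inner: Block) -> Block:
    """Return a copy of `inner` that re-declares each outer iter_var it uses."""
    new_iter_vars: Dict[str, str] = {}
    rename: Dict[str, str] = {}
    for var in sorted(context_dependencies(inner)):
        if var in outer.iter_vars:
            new_name = f"{var}_1"
            new_iter_vars[new_name] = var  # bound to the outer iter_var
            rename[var] = new_name
    new_iter_vars.update(inner.iter_vars)
    new_uses = {rename.get(v, v) for v in inner.body_uses}
    return Block(inner.name, new_iter_vars, new_uses)


outer = Block("outer", {"vi": "i"}, {"vi"})
inner = Block("inner", {"vj": "j"}, {"vi", "vj"})

print(context_dependencies(inner))   # {'vi'}: "inner" relies on its context
fixed = inherit_iter_vars(outer, inner)
print(fixed.iter_vars)               # {'vi_1': 'vi', 'vj': 'j'}
print(context_dependencies(fixed))   # set(): "inner" is now self-contained
```

In this model, moving `fixed` to another location only requires re-binding `vi_1` and `vj`; nothing in the body refers to a variable owned by "outer".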
