uwsampl / SparseTIR

SparseTIR: Sparse Tensor Compiler for Deep Learning
https://sampl.cs.washington.edu/SparseTIR/
Apache License 2.0

[Tracking Issue] Add an argument to determine whether to place init block inside inner block in `blockize` primitive #41

Closed: yzh119 closed this issue 2 years ago.

yzh119 commented 2 years ago

Problem with the current design

Currently, when we blockize a reduction block, the `blockize` primitive places the init block in the outer block by default. This is not always desirable when the outer loops are bound only to data-parallel iter vars. In the example below, the only reduction iter vars of the outer block, vj_1_o and vfi_o, have extent 1:

for io_1, ii_1, fo in T.grid(nnz_1, 4, feat_size):
    with T.block("rgcn-hetero-forward_10_o"):
        vr_1 = T.axis.spatial(1, 0)
        vio_1, vii_1 = T.axis.remap("SS", [io_1, ii_1])
        vj_1_o = T.axis.reduce(1, 0)
        vfo = T.axis.spatial(feat_size, fo)
        vfi_o = T.axis.reduce(1, 0)
        T.reads(II_1_indices[vr_1, vio_1, vii_1], A_1[vr_1, vio_1, vii_1, 0 : 2], W[mid_1[vr_1, vio_1], vfo, 0 : feat_size], mid_1[vr_1, vio_1], X[0 : n, 0 : feat_size], J_1_indices[vr_1, vio_1, vii_1, 0 : 2])
        T.writes(Y[II_1_indices[vr_1, vio_1, vii_1], vfo])
        with T.init():
            with T.block("rgcn-hetero-forward_10_init"):
                T.reads()
                T.writes(Y[II_1_indices[vr_1, vio_1, vii_1], vfo])
                Y[II_1_indices[vr_1, vio_1, vii_1], vfo] = T.float32(0)
        for j_1, fi in T.grid(2, feat_size):
            with T.block("rgcn-hetero-forward_10"):
                vj_1, vfi = T.axis.remap("RR", [j_1, fi])
                T.reads(Y[II_1_indices[vr_1, vio_1, vii_1], vfo], II_1_indices[vr_1, vio_1, vii_1], A_1[vr_1, vio_1, vii_1, vj_1], W[mid_1[vr_1, vio_1], vfo, vfi], mid_1[vr_1, vio_1], X[J_1_indices[vr_1, vio_1, vii_1, vj_1], vfi], J_1_indices[vr_1, vio_1, vii_1, vj_1])
                T.writes(Y[II_1_indices[vr_1, vio_1, vii_1], vfo])
                T.block_attr({"sparse":True})
                Y[II_1_indices[vr_1, vio_1, vii_1], vfo] = Y[II_1_indices[vr_1, vio_1, vii_1], vfo] + A_1[vr_1, vio_1, vii_1, vj_1] * W[mid_1[vr_1, vio_1], vfo, vfi] * X[J_1_indices[vr_1, vio_1, vii_1, vj_1], vfi]

In this case, the inner block reads and writes Y but has no init statement, so it is neither a local complete block nor a local reduction block. As a result, we cannot bind the loops surrounding the inner block (j_1 and fi) to any physical threads.
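To make the classification above concrete, here is a minimal plain-Python sketch (not TVM code). `BlockInfo` and `classify` are hypothetical names, and the checks only roughly follow TensorIR's complete-block and reduction-block conditions: the "local" scope distinction, dominant-writer checks, and affine-binding checks are deliberately left out.

```python
from __future__ import annotations
from dataclasses import dataclass


@dataclass
class BlockInfo:
    """Hypothetical, simplified summary of a TensorIR block (not a TVM type)."""
    iter_kinds: dict[str, str]   # block iter var -> "S" (spatial) or "R" (reduce)
    write_index_vars: set[str]   # block iter vars used to index the written buffer
    reads_its_output: bool       # True if the written region is also read
    has_init: bool               # True if the block carries a T.init() statement


def classify(block: BlockInfo) -> str:
    """Simplified check returning 'complete', 'reduction', or 'neither'."""
    all_spatial = all(kind == "S" for kind in block.iter_kinds.values())
    if all_spatial and not block.reads_its_output:
        return "complete"
    reduce_vars = {v for v, kind in block.iter_kinds.items() if kind == "R"}
    # A reduction block needs an init and must not index its output by reduce vars.
    if block.has_init and not (reduce_vars & block.write_index_vars):
        return "reduction"
    return "neither"


# Inner block produced by the current blockize: reduce axes vj_1/vfi, no init.
# Its output Y[II_1_indices[vr_1, vio_1, vii_1], vfo] uses only outer-block vars.
current_inner = BlockInfo(
    iter_kinds={"vj_1": "R", "vfi": "R"},
    write_index_vars=set(),
    reads_its_output=True,
    has_init=False,
)
# The same block with the init moved inside it, as proposed below.
proposed_inner = BlockInfo(
    iter_kinds={"vj_1": "R", "vfi": "R"},
    write_index_vars=set(),
    reads_its_output=True,
    has_init=True,
)

print(classify(current_inner))   # neither
print(classify(proposed_inner))  # reduction
```

Under these simplified rules, the only difference between the two inner blocks is has_init, which is what the proposed argument changes.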

Proposal

Add an extra argument, `inner_init`, that defaults to False. When it is set to True, `blockize` checks whether any of the outer loops is bound to a reduction iter var. If none is, it creates a new block starting from the given loop and places the init block inside the created inner block. The outer block then keeps only spatial iter vars:

for io_1, ii_1, fo in T.grid(nnz_1, 4, feat_size):
    with T.block("rgcn-hetero-forward_10_o"):
        vr_1 = T.axis.spatial(1, 0)
        vio_1, vii_1 = T.axis.remap("SS", [io_1, ii_1])
        vfo = T.axis.spatial(feat_size, fo)
        T.reads(II_1_indices[vr_1, vio_1, vii_1], A_1[vr_1, vio_1, vii_1, 0 : 2], W[mid_1[vr_1, vio_1], vfo, 0 : feat_size], mid_1[vr_1, vio_1], X[0 : n, 0 : feat_size], J_1_indices[vr_1, vio_1, vii_1, 0 : 2])
        T.writes(Y[II_1_indices[vr_1, vio_1, vii_1], vfo])
        for j_1, fi in T.grid(2, feat_size):
            with T.block("rgcn-hetero-forward_10"):
                vj_1, vfi = T.axis.remap("RR", [j_1, fi])
                T.reads(Y[II_1_indices[vr_1, vio_1, vii_1], vfo], II_1_indices[vr_1, vio_1, vii_1], A_1[vr_1, vio_1, vii_1, vj_1], W[mid_1[vr_1, vio_1], vfo, vfi], mid_1[vr_1, vio_1], X[J_1_indices[vr_1, vio_1, vii_1, vj_1], vfi], J_1_indices[vr_1, vio_1, vii_1, vj_1])
                T.writes(Y[II_1_indices[vr_1, vio_1, vii_1], vfo])
                T.block_attr({"sparse":True})
                with T.init():
                    with T.block("rgcn-hetero-forward_10_init"):
                        T.reads()
                        T.writes(Y[II_1_indices[vr_1, vio_1, vii_1], vfo])
                        Y[II_1_indices[vr_1, vio_1, vii_1], vfo] = T.float32(0)
                Y[II_1_indices[vr_1, vio_1, vii_1], vfo] = Y[II_1_indices[vr_1, vio_1, vii_1], vfo] + A_1[vr_1, vio_1, vii_1, vj_1] * W[mid_1[vr_1, vio_1], vfo, vfi] * X[J_1_indices[vr_1, vio_1, vii_1, vj_1], vfi]
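The proposed check can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code merged in #42: `init_placement`, `LOOPS`, and `BINDINGS` are hypothetical names, each iter var binding is reduced to the set of loop vars it uses (reconstructed from the example above), and falling back to the outer placement when an outer loop feeds a reduction iter var is an assumption, since the issue does not say what happens in that case.

```python
from __future__ import annotations


def init_placement(loops: list[str], blockize_at: str,
                   bindings: dict[str, tuple[str, set[str]]],
                   inner_init: bool = False) -> str:
    """Hypothetical helper: where a blockize-like transform puts T.init().

    loops: loop vars from outermost to innermost.
    blockize_at: the loop passed to blockize; it and all loops below it go inside.
    bindings: block iter var -> (kind, loop vars its binding uses); kind is "S" or "R".
    Returns "outer" or "inner".
    """
    outer_loops = set(loops[: loops.index(blockize_at)])
    # Is any reduction iter var bound (even partly) to a loop outside the new block?
    outer_feeds_reduce = any(
        kind == "R" and bool(used & outer_loops) for kind, used in bindings.values()
    )
    if inner_init and not outer_feeds_reduce:
        return "inner"
    # Default behavior (and, by assumption, the fallback): init stays outside.
    return "outer"


# Loop nest and iter var bindings reconstructed from the example in this issue.
LOOPS = ["io_1", "ii_1", "fo", "j_1", "fi"]
BINDINGS = {
    "vr_1": ("S", set()),       # T.axis.spatial(1, 0): bound to a constant
    "vio_1": ("S", {"io_1"}),
    "vii_1": ("S", {"ii_1"}),
    "vfo": ("S", {"fo"}),
    "vj_1": ("R", {"j_1"}),
    "vfi": ("R", {"fi"}),
}

print(init_placement(LOOPS, "j_1", BINDINGS))                   # outer
print(init_placement(LOOPS, "j_1", BINDINGS, inner_init=True))  # inner
print(init_placement(LOOPS, "fi", BINDINGS, inner_init=True))   # outer
```

Blockizing at fi instead of j_1 would leave the reduction loop j_1 outside the new block; moving the init inward would then reset Y on every j_1 iteration, which is why the check is needed.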
yzh119 commented 2 years ago

Finished in #42.