tlc-pack / relax

Apache License 2.0
193 stars 58 forks source link

[Bug] `RewriteDataflowReshape` pass fails at certain stride ops #444

Closed sunggg closed 1 year ago

sunggg commented 1 year ago

The `RewriteDataflowReshape` pass does not handle the special case of a strided slice whose steps are all ones: it does not check that the input and output tensors have the same total number of elements. As a result, it throws an error like the following:

TVMError: Reshape expects the new shape to be convertible from the old shape. However, the old shape is R.shape([20, 10, 5]), with product T.int64(1000), while the new shape is R.shape([3, 10, 5]), with product T.int64(150)

Here is the reproducible code.

import tvm
from tvm import relax
from tvm.script import relax as R, tir as T

### Case 1: Strided slice with non-unit steps. This works fine.
# Module for the non-unit-step case: `main` calls the TIR function
# `strided_slice` below to turn an (8, 9, 10, 10) float32 tensor into a
# (4, 9, 10, 3) one.
# NOTE(review): this module uses `T.Buffer[...]` subscript syntax while
# `SliceStrideOne` uses `T.Buffer(...)` call syntax — confirm both parse on the
# TVM version under test.
@tvm.script.ir_module
class SliceStrideNonOne:
    @R.function
    def main(
        x: R.Tensor((8, 9, 10, 10), dtype="float32")
    ) -> R.Tensor((4, 9, 10, 3), dtype="float32"):
        # NOTE(review): unlike `SliceStrideOne.main`, this call is not wrapped in
        # `with R.dataflow():`. If the pass only rewrites bindings inside
        # dataflow blocks (as its name suggests), this module may be left
        # untouched for that reason rather than because of the non-unit steps
        # — confirm.
        gv = R.call_tir(strided_slice, (x,), R.Tensor((4, 9, 10, 3), dtype="float32"))
        return gv

    @T.prim_func
    def strided_slice(
        rxplaceholder: T.Buffer[(T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"],
        T_strided_slice_with_axes: T.Buffer[
            (T.int64(4), T.int64(9), T.int64(10), T.int64(3)), "float32"
        ],
    ):
        T.func_attr({"tir.noalias": True})
        for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(9), T.int64(10), T.int64(3)):
            with T.block("T_strided_slice_with_axes"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                # Index mapping: output[ax0, ax1, ax2, ax3] reads
                # input[2 * ax0 + 1, ax1, ax2, 8 - 3 * ax3], i.e. axis 0 starts
                # at 1 with step 2 and axis 3 starts at 8 with step -3; axes 1
                # and 2 are indexed unchanged.
                T.reads(
                    rxplaceholder[
                        ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3)
                    ]
                )
                T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3])
                T_strided_slice_with_axes[ax0, ax1, ax2, ax3] = rxplaceholder[
                    ax0 * T.int64(2) + T.int64(1), ax1, ax2, T.int64(8) - ax3 * T.int64(3)
                ]

# Apply the pass to the non-unit-step module.
before = SliceStrideNonOne
after = relax.transform.RewriteDataflowReshape()(before)
# The pass is expected to leave this module unchanged, so the structural
# equality check passes.
tvm.ir.assert_structural_equal(before, after)

### Case 2: Strided slice with all steps equal to one. This throws an error.
# Module for the unit-step case: `main` calls the TIR function `strided_slice`
# below inside a dataflow block, producing a (3, 10, 5) float32 tensor from a
# (20, 10, 5) input.
@tvm.script.ir_module
class SliceStrideOne:
    @R.function
    def main(x: R.Tensor((20, 10, 5), dtype="float32")) -> R.Tensor((3, 10, 5), dtype="float32"):
        with R.dataflow():
            # Input has 20 * 10 * 5 = 1000 elements and output has
            # 3 * 10 * 5 = 150, so this call is not a valid reshape.
            lv = R.call_tir(strided_slice, (x,), out_sinfo=R.Tensor((3, 10, 5), dtype="float32"))
            gv: R.Tensor((3, 10, 5), dtype="float32") = lv
            R.output(gv)
        return gv

    @T.prim_func
    def strided_slice(
        rxplaceholder: T.Buffer((T.int64(20), T.int64(10), T.int64(5)), "float32"),
        T_strided_slice_with_axes: T.Buffer((T.int64(3), T.int64(10), T.int64(5)), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(10), T.int64(5)):
            with T.block("T_strided_slice_with_axes"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Input and output are indexed identically, but the loop only
                # covers the first 3 of the 20 entries along axis 0.
                # NOTE(review): presumably this identity indexing is what the
                # pass mistakes for a reshape — confirm against the pass.
                T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2])
                T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2])
                T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] = rxplaceholder[v_ax0, v_ax1, v_ax2]

# Apply the pass to the unit-step module. Per this report, this is the case
# that fails with "Reshape expects the new shape to be convertible from the
# old shape" (old product 1000, new product 150).
before = SliceStrideOne
after = relax.transform.RewriteDataflowReshape()(before)
# Print the rewritten module for inspection.
after.show()

cc. @jwfromm @MasterJH5574 @tqchen

yongwww commented 1 year ago

fixed in #445