mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] Check failed: (args.size() == initial_indices_orig.size()) is false #2276

Open jpf888 opened 2 months ago

jpf888 commented 2 months ago

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. I built a model that uses the conv2d op.
  2. One of its computation graphs is: permute_dims --> conv2d --> layernorm.
  3. I encountered the following problem during compilation.

I think this problem is caused by the fusion of the permute and conv operators: when dl.gpu.Matmul() is applied to the fused function, the buffer shape and the index_map shape no longer match.
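For context, the conv2d here has a 1x1 kernel (weight shape (256, 768, 1, 1)), so it is mathematically a matrix multiplication over the channel axis, which is the kind of computation dl.gpu.Matmul tries to normalize. The plain-Python sketch below checks that equivalence for NCHW layout with batch 1, using small hypothetical sizes and no TVM at all; it illustrates the math only, not how dlight detects the pattern. The function names are illustrative, not TVM APIs.

```python
def conv2d_1x1_nchw(x, w):
    """Direct 1x1 conv2d: x is [C][H][W], w is [F][C] (the 1x1 spatial dims dropped)."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    F = len(w)
    return [[[sum(w[f][c] * x[c][y][xx] for c in range(C))
              for xx in range(W)] for y in range(H)] for f in range(F)]


def matmul(a, b):
    """Plain matrix product of an [M][K] and a [K][N] list of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def conv2d_1x1_as_matmul(x, w):
    """Flatten H and W into one axis so the conv becomes W[F, C] @ X[C, H*W]."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    # Row-major flattening: spatial position (y, xx) maps to column y * W + xx.
    x_flat = [[x[c][y][xx] for y in range(H) for xx in range(W)] for c in range(C)]
    out_flat = matmul(w, x_flat)  # shape [F][H*W]
    # Unflatten each output row back to [H][W].
    return [[out_flat[f][y * W:(y + 1) * W] for y in range(H)] for f in range(len(w))]


# Small hypothetical sizes: C=3 input channels, F=2 output channels, 2x4 spatial.
C, F, H, W = 3, 2, 2, 4
x = [[[c * 100 + y * 10 + xx for xx in range(W)] for y in range(H)] for c in range(C)]
w = [[f + c + 1 for c in range(C)] for f in range(F)]
assert conv2d_1x1_nchw(x, w) == conv2d_1x1_as_matmul(x, w)
```

The y * W + xx flattening is the same kind of index arithmetic as the i1 * T.int64(64) + i2 term in the index map reported below.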

1. Error log:

tvm.error.InternalError: Traceback (most recent call last):
  4: operator()
        at /workspace/tvm-unity/src/tir/schedule/schedule.cc:287
  3: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/traced_schedule.cc:678
  2: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/concrete_schedule.cc:993
  1: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1160
  0: tvm::tir::LegalizeIndexMapDType(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::PrimExpr, void> const&)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1106
  File "/workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc", line 1106
InternalError: Check failed: (args.size() == initial_indices_orig.size()) is false:

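2. Other messages:

1) The index map passed to transform_layout:

T.index_map(lambda i0, i1, i2, i3, i4, i5: (T.int64(0), i1 * T.int64(64) + i2, i3))

This map takes six input indices. Why does it not match the buffer it is applied to?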

2) The conv2d_nchw block:

with T.block("conv2d_nchw", no_realize=True):
    v_nn = T.axis.spatial(T.int64(1))
    v_ff = T.axis.spatial(T.int64(256))
    v_yy = T.axis.spatial(T.int64(64))
    v_xx = T.axis.spatial(T.int64(64))
    v_rc = T.axis.reduce(T.int64(768))
    v_ry = T.axis.reduce(T.int64(1))
    v_rx = T.axis.reduce(T.int64(1))
    pad_temp = T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
    B = T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16")
    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], B[v_ff, v_rc, v_ry, v_rx])
    conv2d_nchw = T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
    T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
    with T.init():
        conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * B[v_ff, v_rc, v_ry, v_rx]

3) The fused PrimFunc:

@T.prim_func(private=True)
def main(
    permute_dims161: T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16"),
    vision_tower_vision_tower_high_neck_0_weight1: T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16"),
    compute_intermediate: T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
        conv2d_nchw_intermediate = T.alloc_buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(768), T.int64(64), T.int64(64)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(permute_dims161[v_i0, v_i1, v_i2, v_i3])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = permute_dims161[v_i0, v_i1, v_i2, v_i3]
        for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64), T.int64(768), T.int64(1), T.int64(1)):
            with T.block("conv2d_nchw"):
                v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64)):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])
                T.writes(compute_intermediate[v_i0, v_i1, v_i2, v_i3])
                compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])

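As far as the traceback shows, the failing check in LegalizeIndexMapDType compares the number of indices the index map takes (its initial indices) with the number of dimensions passed in for the buffer being transformed. The plain-Python sketch below mimics only that arity check, using the reported index map rewritten as a Python function; check_index_map_arity and the example shapes are hypothetical illustrations, not TVM code.

```python
import inspect


def check_index_map_arity(index_map, buffer_shape):
    """Toy version of the failed check: one index-map input per buffer dimension."""
    n_inputs = len(inspect.signature(index_map).parameters)
    if n_inputs != len(buffer_shape):
        raise ValueError(
            f"index map takes {n_inputs} indices but the buffer has "
            f"{len(buffer_shape)} dimensions"
        )
    return n_inputs


def index_map(i0, i1, i2, i3, i4, i5):
    """The reported T.index_map, with T.int64 constants replaced by Python ints."""
    return (0, i1 * 64 + i2, i3)


# A 4-D buffer (hypothetical shape) cannot be transformed by a 6-input map ...
try:
    check_index_map_arity(index_map, (1, 768, 64, 64))
except ValueError as err:
    print(err)  # index map takes 6 indices but the buffer has 4 dimensions

# ... while any 6-D buffer shape passes the arity check.
assert check_index_map_arity(index_map, (1, 64, 64, 768, 1, 1)) == 6
```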

Expected behavior

Environment

Additional context

tqchen commented 1 month ago

Thanks for reporting. If it is possible to get a minimal repro, that would be helpful. You can do so by dumping out the TVMScript before the transform, minimizing it, and running the transform you mentioned.

senlyu163 commented 1 month ago

@jpf888 I ran into the same problem. Did you solve it?

senlyu163 commented 1 month ago


Hi, this bug can be reproduced as follows:

import tvm
from tvm import dlight as dl
from tvm.relax.frontend import nn

class _Conv2d(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.modules.Conv2D(
            in_channels=768,
            out_channels=256,
            kernel_size=1,  # the error only occurs when kernel_size == 1
            padding=0,
            bias=False
        )
    def forward(self, x: nn.Tensor):
        return self.conv(x)

from tvm.relax.frontend.nn import spec
forward_spec = {
    "forward": {
        "x": spec.Tensor([1, 768, 64, 64], dtype="float32")
    }
}
_conv2d_mod, params = _Conv2d().export_tvm(
    spec=forward_spec,
    debug=True
)

mod = _conv2d_mod

def _pipeline(mod):
    seq = tvm.transform.Sequential(
        [
            tvm.relax.transform.LegalizeOps(),
            tvm.relax.transform.AnnotateTIROpPattern(),
            tvm.relax.transform.FoldConstant(),
            tvm.relax.transform.FuseOps(),
            tvm.relax.transform.FuseTIR(),
            # dl.gpu.Matmul is the rule that raises the error for this 1x1 conv2d
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            )
        ]
    )
    mod = seq(mod)
    return mod

with tvm.target.Target("nvidia/geforce-rtx-4090", host="llvm"):
    mod = _pipeline(mod)

mod.show()

dl.gpu.Matmul reports this error after tvm.relax.transform.LegalizeOps() only when kernel_size is 1. Multimodal LLMs (MLLMs) may include an image-embedding stage that uses exactly this operation (a conv2d with kernel_size=1).
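One way to see why only kernel_size=1 is affected: the conv2d block reads pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx]. When ry and rx have extent 1 they can only be 0, so an arithmetic simplifier may rewrite v_yy + v_ry to v_yy and drop v_ry and v_rx from the access; with a larger kernel both must stay. The toy model below (plain Python, not TVM's arithmetic analyzer; access_vars is a hypothetical helper) counts the iteration variables left in that access with and without such a simplification. Under this assumption the count falls from six, the arity of the reported index map, to four for a 1x1 kernel.

```python
def access_vars(kernel_size, simplify=True):
    """Toy model of the iteration variables in the pad_temp access
    pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] of the conv2d block.

    Each index is a list of (var, extent) terms that are summed together.
    With simplify=True, a term whose extent is 1 (its only value is 0) is
    dropped when it is added to another term, mimicking `v_yy + 0 -> v_yy`.
    """
    indices = [
        [("v_nn", 1)],
        [("v_rc", 768)],
        [("v_yy", 64), ("v_ry", kernel_size)],
        [("v_xx", 64), ("v_rx", kernel_size)],
    ]
    remaining = []
    for terms in indices:
        if simplify and len(terms) > 1:
            # Keep only non-unit-extent terms; fall back to the first term if none.
            terms = [t for t in terms if t[1] != 1] or terms[:1]
        remaining.extend(name for name, _ in terms)
    return remaining


for k in (1, 3):
    print(k, len(access_vars(k)), len(access_vars(k, simplify=False)))
# kernel_size=1: 4 vars simplified vs 6 unsimplified; kernel_size=3: 6 vs 6
```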

LeiWang1999 commented 1 month ago

Hi @senlyu163, it looks like this is a known issue when applying dlight to a conv2d with a kernel size of 1. It arises because the reindex schedule primitive simplifies the index expressions. To address this, I previously created a draft PR. You can merge the relevant changes and modify dlight's normalize_to_matmul function.

Check out this draft PR: https://github.com/apache/tvm/pull/16440

The key component related to this issue is the addition of a skip_simplify flag to cache_reindex. You can apply the relevant changes as follows:

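import logging
from typing import List, Optional

from tvm import tir
from tvm.tir.schedule import BlockRV

logger = logging.getLogger(__name__)

# get_index_map(block, layout=...) is assumed to come from dlight's matmul
# utilities (e.g. tvm/dlight/gpu/matmul.py); adjust the import to your code base.


def normalize_to_matmul(sch: tir.Schedule,
                        main_block: BlockRV,
                        layout: Optional[List[str]] = None) -> Optional[tir.Schedule]: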
    if layout is None:
        layout = ["n", "t", "n"]
    block_stmt = sch.get(main_block)

    # Let layout be 'a' to auto infer the layout
    index_maps = get_index_map(block_stmt, layout=layout)
    if index_maps is None:
        logger.debug("Cannot find the appropriate index map for tensorcore")
        return None

    matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

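    # `skip_simplify` comes from the draft PR (a stock Schedule.reindex may not accept it).
    # It stops reindex from simplifying the access indices, which avoids the 1x1 conv bug.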
    block = sch.reindex(main_block, ("read", 0), skip_simplify=True)
    sch.transform_layout(block, ("write", 0), a_index_map)
    block = sch.reindex(main_block, ("read", 1), skip_simplify=True)
    sch.transform_layout(block, ("write", 0), b_index_map)
    block = sch.reindex(main_block, ("write", 0), skip_simplify=True)
    sch.transform_layout(block, ("read", 0), c_index_map)
    sch.transform_block_layout(main_block, matmul_index_map)
    sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True)
    return sch