apache / tvm

Open deep learning compiler stack for CPU, GPU and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] MetaSchedule layout rewrite causes workload mismatch #13148

Open vinx13 opened 1 year ago

vinx13 commented 1 year ago

TOPI batch_matmul mistakenly marks the RHS tensor of batch matmul as a layout-free placeholder even when it is a variable. As a result, RewriteLayout is applied to it, and the resulting layout transform cannot be constant-folded in Relay. This causes a workload mismatch: the meta_schedule_layout_transform op is fused with other operators, producing a new workload that has not been tuned.
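
For context, here is a minimal sketch of the compute pattern in question (simplified for illustration, not the exact TOPI source), showing how the RHS gets tagged as layout-free unconditionally, whether or not it is a constant weight:

from tvm import te

def batch_matmul_sketch(x, y, out_dtype="int32"):
    # x: (B, M, K), y: (B, N, K) -- the transpose_b layout used in the repro below
    b, m, k = x.shape
    _, n, _ = y.shape
    kk = te.reduce_axis((0, k), name="k")
    return te.compute(
        (b, m, n),
        lambda bb, i, j: te.sum(
            x[bb, i, kk].astype(out_dtype) * y[bb, j, kk].astype(out_dtype), axis=kk
        ),
        name="T_batch_matmul",
        tag="batch_matmul",
        # The RHS is tagged as a layout-free placeholder regardless of whether it
        # is a constant weight or a graph input; RewriteLayout then rewrites it
        # and the resulting layout transform cannot be constant-folded in Relay.
        attrs={"layout_free_placeholders": [y]},
    )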

Expected behavior

Successfully tune and compile the Relay function.

Actual behavior

One workload is missing from tuning database.

src/relay/backend/te_compiler_cache.cc:544: Warning: Cannot find workload: vm_mod_fused_transpose_meta_schedule_layout_transform
# from tvm.script import tir as T
@T.prim_func
def func(p0: T.Buffer[(12, 64, 197), "int8"], T_meta_schedule_layout_trans: T.Buffer[(12, 64, 197), "int8"]) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    T_transpose = T.alloc_buffer([12, 197, 64], dtype="int8")
    for i0, i1, i2 in T.grid(12, 197, 64):
        with T.block("T_transpose"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(p0[ax0, ax2, ax1])
            T.writes(T_transpose[ax0, ax1, ax2])
            T_transpose[ax0, ax1, ax2] = p0[ax0, ax2, ax1]
    for i0, i1, i2 in T.grid(12, 64, 197):
        with T.block("T_meta_schedule_layout_trans"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(T_transpose[ax0, ax2, ax1])
            T.writes(T_meta_schedule_layout_trans[ax0, ax1, ax2])
            T_meta_schedule_layout_trans[ax0, ax1, ax2] = T_transpose[ax0, ax2, ax1]
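
One way to confirm the mismatch (a hedged sketch, assuming the dumped PrimFunc above has been pasted into a script as func and that database is the object returned by tune_relay in the repro below) is to ask the tuning database directly:

import tvm

# Wrap the fused PrimFunc printed above into an IRModule and query the database;
# the workload is absent because layout rewrite only creates it at compile time.
tir_mod = tvm.IRModule({"main": func})
print(database.has_workload(tir_mod))  # expected: False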

Environment

TVM v0.11.dev 5eab64885ad4

Steps to reproduce

import tvm
from tvm import relay

# batch_matmul whose RHS comes from a transposed graph input (a variable, not a constant)
x = relay.var("x", shape=(12, 197, 64), dtype="int8")
y = relay.var("y", shape=[12, 64, 197], dtype="int8")
y1 = relay.transpose(y, [0, 2, 1])
mm = relay.nn.batch_matmul(x, y1, out_dtype="int32", transpose_b=True)

func = relay.Function([x, y], mm)
mod = tvm.ir.IRModule({"main": func})

# Tune with MetaSchedule, then compile using the resulting database.
import tvm.meta_schedule as ms
target = tvm.target.Target("aws/cpu/c5.12xlarge")
database = ms.relay_integration.tune_relay(mod, {}, target=target, work_dir="./work_dir", max_trials_global=200)
lib = ms.relay_integration.compile_relay(database=database, mod=mod, target=target, params={}, backend='vm')
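
As a sanity check (a sketch reusing the mod and target defined above), listing the extracted tasks shows that the fused transpose + meta_schedule_layout_transform workload from the warning is never presented to the tuner:

# Print the task names MetaSchedule extracts from the module before tuning.
for task in ms.relay_integration.extract_tasks(mod, target=target, params={}):
    print(task.task_name)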

cc @zxybazh @junrushao

masahi commented 1 year ago

A related issue from AS, probably due to the same root cause: https://github.com/apache/tvm/issues/9476

masahi commented 1 year ago

Actually, if I replace the above repro with

mm = relay.nn.batch_matmul(x, x, out_dtype="int32", transpose_b=True)
func = relay.Function([x], mm)

I get a segfault from MS. Disabling RewriteLayout makes the error go away.
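
For anyone blocked by this, a possible workaround sketch until a fix lands: build a space generator whose postproc list leaves out RewriteLayout and pass it to tune_relay. This assumes tune_relay accepts a space argument on this version and that the listed postprocs match the CPU defaults; parameter names may differ.

from tvm import meta_schedule as ms

# Assumed default-like CPU postprocs, with RewriteLayout deliberately omitted.
space = ms.space_generator.PostOrderApply(
    postprocs=[
        ms.postproc.DisallowDynamicLoop(),
        ms.postproc.RewriteParallelVectorizeUnroll(),
        ms.postproc.RewriteReductionBlock(),
    ],
)
database = ms.relay_integration.tune_relay(
    mod, {}, target=target, work_dir="./work_dir", max_trials_global=200, space=space
)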