apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0
11.82k stars 3.48k forks source link

[Bug] [CodeGen] InternalError: Check failed: (is_zero(op->min)) is false during CodeGen #17386

Open talha-ahsan opened 2 months ago

talha-ahsan commented 2 months ago

Expected behavior

A successful compilation, or a message explaining which part of the program being compiled is invalid. As far as I can tell, this should at least surface a more user-friendly error message if the TIR is invalid.

Actual behavior

I get a long trace (added below) which ends with this message: InternalError: Check failed: (is_zero(op->min)) is false:

This seems to fail in the tir.LowerIntrin pass

Steps to reproduce

This script causes the exception to occur:

import tvm
from tvm import tir
from tvm.tir.analysis.analysis import verify_well_formed, verify_memory

from tvm.script import tir as T

# TVMScript kernel reproducing TVM issue #17386: building this PrimFunc with
# `tvm.build` crashes in LLVM CPU codegen with
# `InternalError: Check failed: (is_zero(op->min))`.
# The body is kept byte-for-byte as reported so the crash stays reproducible;
# see the NOTE(review) comments below for the lines that look malformed.
@T.prim_func
def tvmgen_default_fused_nn_conv2d_3(p0: T.Buffer((1, 256, 56, 56), "float32"), p1: T.Buffer((256, 256, 3, 3), "float32"), output_unpack: T.Buffer((1, 256, 56, 56), "float32")):
    T.func_attr({"from_legacy_te_schedule": T.bool(True), "hash": "032dbe00302af996", "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.noalias": T.bool(True)})
    # Two scratch allocations that are re-viewed under several flat Buffer
    # aliases below (data_vec_1/data_vec_2 share `data_vec`).
    data_vec = T.allocate([802816], "float32", "global")
    data_pad = T.allocate([861184], "float32", "global")
    # Stage 1: repack the input activations into a vectorized layout
    # (presumably NCHW -> NCHW4c; inferred from the *4 inner stride -- unverified).
    data_vec_1 = T.Buffer((802816,), data=data_vec)
    for bs_c_fused_h_fused in T.parallel(3584):
        for w, vc in T.grid(56, 4):
            p0_1 = T.Buffer((802816,), data=p0.data)
            data_vec_1[bs_c_fused_h_fused * 224 + w * 4 + vc] = p0_1[bs_c_fused_h_fused // 56 * 12544 + vc * 3136 + bs_c_fused_h_fused % 56 * 56 + w]
    # Stage 2: zero-pad the repacked activations (58x58 spatial extent,
    # interior 1..56 copied, border filled with broadcast 0).
    data_pad_1 = T.Buffer((861184,), data=data_pad)
    for i0_i1_fused_i2_fused in T.parallel(3712):
        for i3 in range(58):
            cse_var_2: T.int32 = i0_i1_fused_i2_fused % 58
            cse_var_1: T.int32 = i3 * 4
            data_pad_1[i0_i1_fused_i2_fused * 232 + cse_var_1:i0_i1_fused_i2_fused * 232 + cse_var_1 + 4] = T.if_then_else(1 <= cse_var_2 and cse_var_2 < 57 and 1 <= i3 and i3 < 57, data_vec_1[i0_i1_fused_i2_fused // 58 * 12544 + cse_var_2 * 224 + cse_var_1 - 228:i0_i1_fused_i2_fused // 58 * 12544 + cse_var_2 * 224 + cse_var_1 - 228 + 4], T.Broadcast(T.float32(0), 4))
    # Stage 3: repack the 3x3 weights; note data_vec_2 aliases the same
    # allocation as data_vec_1, overwriting the stage-1 buffer.
    data_vec_2 = T.Buffer((589824,), data=data_vec)
    for occ_k_h_fused in T.parallel(192):
        for icc, k_w, icb in T.grid(64, 3, 4):
            cse_var_4: T.int32 = occ_k_h_fused % 3
            cse_var_3: T.int32 = occ_k_h_fused // 3 * 9216
            p1_1 = T.Buffer((589824,), data=p1.data)
            data_vec_2[cse_var_3 + icc * 144 + cse_var_4 * 48 + k_w * 16 + icb * 4:cse_var_3 + icc * 144 + cse_var_4 * 48 + k_w * 16 + icb * 4 + 4] = p1_1[cse_var_3 + icc * 36 + icb * 9 + cse_var_4 * 3 + k_w:cse_var_3 + icc * 36 + icb * 9 + cse_var_4 * 3 + k_w + 9216:2304]
    # Stage 4: the main conv2d_NCHWc compute, tiled 2x28 along the output
    # width, accumulating float32x4 vectors.
    for n_c_outer_fused_h_fused in T.parallel(3584):
        conv2d_NCHWc = T.allocate([56], "float32x4", "global")
        conv2d_NCHWc_global = T.allocate([28], "float32x4", "global")
        conv2d_NCHWc_1 = T.Buffer((56,), "float32x4", data=conv2d_NCHWc)
        for ow_outer in range(2):
            # Fully-unrolled zero-initialization of the 28 accumulators.
            conv2d_NCHWc_global_1 = T.Buffer((28,), "float32x4", data=conv2d_NCHWc_global)
            conv2d_NCHWc_global_1[0] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[1] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[2] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[3] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[4] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[5] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[6] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[7] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[8] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[9] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[10] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[11] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[12] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[13] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[14] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[15] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[16] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[17] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[18] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[19] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[20] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[21] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[22] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[23] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[24] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[25] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[26] = T.Broadcast(T.float32(0), 4)
            conv2d_NCHWc_global_1[27] = T.Broadcast(T.float32(0), 4)
            # Reduction over input-channel tiles and the 3x3 kernel window,
            # unrolled 28x along the output width.
            for ic_outer, kh, kw, ic_inner in T.grid(64, 3, 3, 4):
                cse_var_6: T.int32 = n_c_outer_fused_h_fused // 56 * 9216 + ic_outer * 144 + kh * 48 + kw * 16 + ic_inner * 4
                cse_var_5: T.int32 = ic_outer * 13456 + kh * 232 + n_c_outer_fused_h_fused % 56 * 232 + ow_outer * 112 + kw * 4 + ic_inner
                conv2d_NCHWc_global_1[0] = conv2d_NCHWc_global_1[0] + T.Broadcast(data_pad_1[cse_var_5], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[1] = conv2d_NCHWc_global_1[1] + T.Broadcast(data_pad_1[cse_var_5 + 4], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[2] = conv2d_NCHWc_global_1[2] + T.Broadcast(data_pad_1[cse_var_5 + 8], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[3] = conv2d_NCHWc_global_1[3] + T.Broadcast(data_pad_1[cse_var_5 + 12], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[4] = conv2d_NCHWc_global_1[4] + T.Broadcast(data_pad_1[cse_var_5 + 16], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[5] = conv2d_NCHWc_global_1[5] + T.Broadcast(data_pad_1[cse_var_5 + 20], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[6] = conv2d_NCHWc_global_1[6] + T.Broadcast(data_pad_1[cse_var_5 + 24], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[7] = conv2d_NCHWc_global_1[7] + T.Broadcast(data_pad_1[cse_var_5 + 28], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[8] = conv2d_NCHWc_global_1[8] + T.Broadcast(data_pad_1[cse_var_5 + 32], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[9] = conv2d_NCHWc_global_1[9] + T.Broadcast(data_pad_1[cse_var_5 + 36], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[10] = conv2d_NCHWc_global_1[10] + T.Broadcast(data_pad_1[cse_var_5 + 40], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[11] = conv2d_NCHWc_global_1[11] + T.Broadcast(data_pad_1[cse_var_5 + 44], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[12] = conv2d_NCHWc_global_1[12] + T.Broadcast(data_pad_1[cse_var_5 + 48], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[13] = conv2d_NCHWc_global_1[13] + T.Broadcast(data_pad_1[cse_var_5 + 52], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[14] = conv2d_NCHWc_global_1[14] + T.Broadcast(data_pad_1[cse_var_5 + 56], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[15] = conv2d_NCHWc_global_1[15] + T.Broadcast(data_pad_1[cse_var_5 + 60], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[16] = conv2d_NCHWc_global_1[16] + T.Broadcast(data_pad_1[cse_var_5 + 64], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[17] = conv2d_NCHWc_global_1[17] + T.Broadcast(data_pad_1[cse_var_5 + 68], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[18] = conv2d_NCHWc_global_1[18] + T.Broadcast(data_pad_1[cse_var_5 + 72], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[19] = conv2d_NCHWc_global_1[19] + T.Broadcast(data_pad_1[cse_var_5 + 76], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[20] = conv2d_NCHWc_global_1[20] + T.Broadcast(data_pad_1[cse_var_5 + 80], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[21] = conv2d_NCHWc_global_1[21] + T.Broadcast(data_pad_1[cse_var_5 + 84], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[22] = conv2d_NCHWc_global_1[22] + T.Broadcast(data_pad_1[cse_var_5 + 88], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[23] = conv2d_NCHWc_global_1[23] + T.Broadcast(data_pad_1[cse_var_5 + 92], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[24] = conv2d_NCHWc_global_1[24] + T.Broadcast(data_pad_1[cse_var_5 + 96], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[25] = conv2d_NCHWc_global_1[25] + T.Broadcast(data_pad_1[cse_var_5 + 100], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[26] = conv2d_NCHWc_global_1[26] + T.Broadcast(data_pad_1[cse_var_5 + 104], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
                conv2d_NCHWc_global_1[27] = conv2d_NCHWc_global_1[27] + T.Broadcast(data_pad_1[cse_var_5 + 108], 4) * data_vec_2[cse_var_6:cse_var_6 + 4]
            for ow_inner in range(28):
                conv2d_NCHWc_1[ow_outer * 28 + ow_inner] = conv2d_NCHWc_global_1[ow_inner]
        # Stage 5: write-out / unpack loop.
        for w_outer in range(2):
            # NOTE(review): this loop's minimum is not zero (it starts at
            # -1075824214, not 0), which is exactly what trips the
            # `Check failed: (is_zero(op->min))` assertion in
            # src/target/llvm/codegen_cpu.cc during CodeGen -- the TIR is
            # malformed here rather than the compiler pass.  The loop extent
            # (28 iterations) matches the tile width, but the offset makes
            # every index wildly out of bounds.
            for w_inner in range(-1075824214, -1075824186):
                cse_var_7: T.int32 = w_outer * 28
                output_unpack_1 = T.Buffer((802816,), data=output_unpack.data)
                # NOTE(review): `conv2d_NCHWc_1[cse_var_7 + cse_var_7]` looks
                # like it should read `cse_var_7 + w_inner` -- as written it
                # ignores w_inner entirely.  Kept byte-identical to preserve
                # the reported reproduction.
                output_unpack_1[n_c_outer_fused_h_fused // 56 * 12544 + n_c_outer_fused_h_fused % 56 * 56 + cse_var_7 + w_inner:n_c_outer_fused_h_fused // 56 * 12544 + n_c_outer_fused_h_fused % 56 * 56 + cse_var_7 + w_inner + 12544:3136] = conv2d_NCHWc_1[cse_var_7 + cse_var_7]

# Driver: wrap the repro PrimFunc in an IRModule, validate it, and attempt a
# build at opt_level=0 (the build is expected to crash -- see issue #17386).
func = tvmgen_default_fused_nn_conv2d_3
mod = tvm.ir.IRModule({'main': func})
# Bug fix: the original condition `not verify_well_formed(mod) and
# verify_memory(func)` parses as `(not verify_well_formed(mod)) and
# verify_memory(func)` because `not` binds tighter than `and`.  A module that
# failed well-formedness AND memory verification (or only memory
# verification) therefore fell through to the build branch.  Group the two
# checks so we only build when BOTH validations pass.
if not (verify_well_formed(mod) and verify_memory(func)):
    print("Validation failed")
else:
    with tvm.transform.PassContext(opt_level=0):
        nopt_mod = tvm.build(mod)
    print("Success!")

Dump

Traceback (most recent call last):
  File "<path_to_script>/TVMBugReport3/reprod.py", line 113, in <module>
    nopt_mod = tvm.build(mod)
  File "<path_to_tvm>/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "<path_to_tvm>/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "<path_to_tvm>/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  119: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::$_5>(tvm::$_5, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  118: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  117: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  116: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::codegen::$_8>(tvm::codegen::$_8, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  115: tvm::codegen::LLVMModuleNode::Init(tvm::IRModule const&, tvm::Target const&)
  114: void tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator, tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator>(tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator, tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator)::{lambda(auto:1)#1}>(tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator, tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator, tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator>(tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator, tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void>::iterator)::{lambda(auto:1)#1})
  113: tvm::codegen::CodeGenCPU::AddFunction(tvm::GlobalVar const&, tvm::tir::PrimFunc const&)
  112: tvm::codegen::CodeGenLLVM::AddFunctionInternal(tvm::GlobalVar const&, tvm::tir::PrimFunc const&)
  111: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  110: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  109: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  108: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  107: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  106: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  105: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  104: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  103: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  102: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  101: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  100: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  99: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  98: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  97: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  96: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  95: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  94: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  93: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  92: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  91: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  90: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  89: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  88: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  87: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  86: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  85: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  84: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  83: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  82: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  81: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  80: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  79: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  78: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  77: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  76: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  75: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  74: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  73: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  72: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  71: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  70: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  69: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  68: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  67: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  66: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  65: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  64: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  63: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  62: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  61: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  60: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  59: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  58: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  57: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  56: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  55: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  54: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  53: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  52: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  51: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  50: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  49: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  48: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  47: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  46: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  45: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  44: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  43: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  42: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  41: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  40: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  39: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  38: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  37: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  36: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  35: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  34: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  33: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  32: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  31: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  30: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  29: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  28: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  27: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  26: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  25: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  24: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  23: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  22: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
  21: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
  20: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  19: tvm::codegen::CodeGenCPU::CreateComputeScope(tvm::tir::AttrStmtNode const*)
  18: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  17: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  16: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  15: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  14: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
  13: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
  12: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
  11: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  10: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::ForNode const*)
  9: tvm::codegen::CodeGenCPU::CreateParallelLaunch(tvm::tir::Stmt const&, int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
  8: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::ForNode const*)
  7: tvm::codegen::CodeGenLLVM::CreateSerialFor(llvm::Value*, llvm::Value*, llvm::Value*, tvm::tir::Var const&, tvm::tir::Stmt const&)
  6: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AllocateNode const*)
  5: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AllocateNode const*)
  4: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
  3: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::ForNode const*)
  2: tvm::codegen::CodeGenLLVM::CreateSerialFor(llvm::Value*, llvm::Value*, llvm::Value*, tvm::tir::Var const&, tvm::tir::Stmt const&)
  1: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::ForNode const*)
  0: _ZN3tvm7runtime6detail
  File "<path_to_tvm>/tvm/src/target/llvm/codegen_cpu.cc", line 1482
InternalError: Check failed: (is_zero(op->min)) is false: 

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

talha-ahsan commented 1 month ago

@Lunderberg Hi, sorry for the ping, but I was wondering if it would be possible to get some followup on this issue, #17387 and #17388