Open zifeitong opened 1 month ago
Thanks for reporting this issue, I'll take a look.
I tried running inference code again with CUDA_LAUNCH_BLOCKING=1
and I got
tvm._ffi.base.TVMError: Traceback (most recent call last):
2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_6detail17PackFuncVoidAddr_ILi4ENS0_15CUDAWrappedFuncEEENS0_10PackedFuncET0_RKSt6vectorINS4_1
1: tvm::runtime::CUDAWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*, void**) const [clone .isra.0]
0: _ZN3tvm7runtime6deta
File "/workspace/tvm/src/runtime/cuda/cuda_module.cc", line 206
TVMError: CUDALaunch Error: CUDA_ERROR_ILLEGAL_INSTRUCTION
grid=(3072,1,1), block=(64,4,1)
Note that this grid and block configuration does not necessarily belong to a flashinfer kernel.
Need more investigation.
It seems that grid=(3072,1,1), block=(64,4,1)
is the kernel launch configuration of the TIR function NT_matmul
:
# TVMScript dump of the IRModule containing the scheduled TIR kernel NT_matmul.
# NOTE: the original indentation of this dump was lost; loop nesting must be
# inferred from the loop variables each block uses.
@I.ir_module
class Module:
I.module_attrs({"external_mods": [metadata["runtime.Module"][0], metadata["runtime.Module"][1], metadata["runtime.Module"][2], metadata["runtime.Module"][3], metadata["runtime.Module"][4], metadata["runtime.Module"][5], metadata["runtime.Module"][6], metadata["runtime.Module"][7], metadata["runtime.Module"][8], metadata["runtime.Module"][9], metadata["runtime.Module"][10], metadata["runtime.Module"][11]]})
# NT_matmul: float16 product of a (1, 1, 4096) input with a (12288, 4096)
# weight, written to a (1, 1, 12288) output. Both operands are indexed
# [.., k] along the 4096 axis, i.e. the weight is used in transposed (NT) form.
# Note that the output buffer parameter is also named NT_matmul.
@T.prim_func(private=True)
def NT_matmul(rms_norm65: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), model_layers_0_self_attn_qkv_proj_weight3: T.Buffer((T.int64(12288), T.int64(4096)), "float16"), NT_matmul: T.Buffer((T.int64(1), T.int64(1), T.int64(12288)), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
# Intermediate buffers: two "local"-scope partial-sum buffers (256 and 64
# partials per output column), a "local"-scope copy of the weight, and a
# "shared"-scope copy of the input vector.
NT_matmul_rf_local = T.alloc_buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(12288)), "float16", scope="local")
NT_matmul_rf_local_1 = T.alloc_buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(12288)), "float16", scope="local")
model_layers_0_self_attn_qkv_proj_weight3_local = T.alloc_buffer((T.int64(12288), T.int64(4096)), "float16", scope="local")
rms_norm65_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared")
# Thread bindings: blockIdx.x extent 3072, threadIdx.y extent 4, threadIdx.x
# extent 64. These match the grid=(3072,1,1), block=(64,4,1) launch
# configuration reported in the CUDA error.
for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(3072), thread="blockIdx.x"):
for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
# Stage 1: copy the whole 4096-element input into rms_norm65_shared.
# 2 serial iterations x 4 (threadIdx.y) x 64 (threadIdx.x) x 8 (vectorized)
# = 4096 elements.
for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
for ax2_0 in T.serial(T.int64(2), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
for ax2_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for ax2_2 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
for ax2_3 in T.vectorized(T.int64(8)):
with T.block("rms_norm65_shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(2048) + ax2_1 * T.int64(512) + ax2_2 * T.int64(8) + ax2_3)
T.reads(rms_norm65[v0, v1, v2])
T.writes(rms_norm65_shared[v0, v1, v2])
rms_norm65_shared[v0, v1, v2] = rms_norm65[v0, v1, v2]
# Stage 2a: zero the 4 partial sums (threadIdx.x * 4 + 0..3) for output
# column v0 = blockIdx.x * 4 + threadIdx.y.
for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(4)):
with T.block("NT_matmul_rf_init"):
vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1_init)
v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
T.reads()
T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
# Stage 2b: 8 outer reduction steps. In each step a thread first loads
# 4 x 2 = 8 weight elements of row v0 (columns step * 512 + threadIdx.x * 8 + 0..7)
# into the local weight buffer ...
for ax1_fused_u_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_ax1_fused_0 in range(T.int64(4)):
for ax0_ax1_fused_1 in T.vectorized(T.int64(2)):
with T.block("model_layers_0_self_attn_qkv_proj_weight3_local"):
v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1)
v1 = T.axis.spatial(T.int64(4096), ax1_fused_u_fused_0 * T.int64(512) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(8) + ax0_ax1_fused_0 * T.int64(2) + ax0_ax1_fused_1)
T.reads(model_layers_0_self_attn_qkv_proj_weight3[v0, v1])
T.writes(model_layers_0_self_attn_qkv_proj_weight3_local[v0, v1])
model_layers_0_self_attn_qkv_proj_weight3_local[v0, v1] = model_layers_0_self_attn_qkv_proj_weight3[v0, v1]
# ... and then accumulates input * weight into its 4 partial sums. The
# reduction index is step * 512 + (rf // 4) * 8 + r2 * 4 + rf % 4, where
# rf = threadIdx.x * 4 + lane is the partial-sum index and r2 ranges over 0..1.
for u_fused_ax0_fused_fused_2, ax1_fused_u_fused_2 in T.grid(T.int64(1), T.int64(2)):
for ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 in T.vectorized(T.int64(4)):
with T.block("NT_matmul_rf_update"):
vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused = T.axis.spatial(T.int64(256), ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + ax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1)
v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
vax1_fused_u_fused_0, vax1_fused_u_fused_2 = T.axis.remap("RR", [ax1_fused_u_fused_0, ax1_fused_u_fused_2])
T.reads(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0], rms_norm65_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)], model_layers_0_self_attn_qkv_proj_weight3_local[v0, vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)])
T.writes(NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0])
NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused, T.int64(0), T.int64(0), v0] + rms_norm65_shared[T.int64(0), T.int64(0), vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)] * model_layers_0_self_attn_qkv_proj_weight3_local[v0, vax1_fused_u_fused_0 * T.int64(512) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused // T.int64(4) * T.int64(8) + vax1_fused_u_fused_2 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused % T.int64(4)]
# Stage 3: for each (threadIdx.y, threadIdx.x), zero NT_matmul_rf_local_1 and
# sum the 4 partials NT_matmul_rf_local[threadIdx.x * 4 + 0..3] into it,
# reducing 256 partial sums per output column to 64.
for ax2_fused_0_ax2_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
for ax2_fused_2_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax2_fused_2_1 in T.vectorized(T.int64(1)):
with T.block("NT_matmul_rf_init"):
vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0)
v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
T.reads()
T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0)
for ax1 in range(T.int64(4)):
with T.block("NT_matmul_rf_update"):
vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + ax2_fused_0_ax2_fused_1_fused + ax2_fused_2_0 + ax2_fused_2_1)
T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0], NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0])
T.writes(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] = NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0] + NT_matmul_rf_local[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 * T.int64(4) + vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_1, T.int64(0), T.int64(0), v0]
# Stage 4: final reduction of the 64 partials into NT_matmul[0, 0, v0]. The
# reduce axis is bound to threadIdx.x.
# NOTE(review): presumably lowered to a cross-thread reduction; confirm
# against the generated CUDA source.
for ax1_fused_2 in range(T.int64(1)):
for ax1_fused_0_ax1_fused_1_fused in T.thread_binding(T.int64(4), thread="threadIdx.y"):
for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
with T.block("NT_matmul"):
vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0)
v0 = T.axis.spatial(T.int64(12288), u_fused_ax0_fused_fused_0 * T.int64(4) + ax1_fused_0_ax1_fused_1_fused + ax1_fused_2)
T.reads(NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0])
T.writes(NT_matmul[T.int64(0), T.int64(0), v0])
with T.init():
NT_matmul[T.int64(0), T.int64(0), v0] = T.float16(0)
NT_matmul[T.int64(0), T.int64(0), v0] = NT_matmul[T.int64(0), T.int64(0), v0] + NT_matmul_rf_local_1[vax1_fused_u_fused_1_ax1_fused_u_fused_3_fused_0, T.int64(0), T.int64(0), v0]
Thank you so much, @yzh119, for triaging and locating the issue. We'll look into it further and report back with updates.
Exact same issue here.
🐛 Bug
I am seeing
illegal instruction
error when running mlc-llm on an H100 (Driver Version: 550.54.1, CUDA Version: 12.4). It works fine on another A100 machine (Driver Version: 535.161.08, CUDA Version: 12.2).
To Reproduce
Steps to reproduce the behavior:
Expected behavior
Environment
Thanks!