apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Missing PackedFunc with Specific Transformation Sequence in Relax Module #17357

Open · Thrsu opened this issue 2 months ago

Thrsu commented 2 months ago

I encountered an issue while running a Relax module with a specific transformation sequence. With the sequence FuseTIR(), LambdaLift(), AllocateWorkspace(), the VM fails to find the PackedFunc fused_relax_nn_attention_cutlass_gv. However, when FuseTIR() is applied a second time, between LambdaLift() and AllocateWorkspace(), the problem disappears.

Expected behavior

The script is expected to run successfully without errors.

Actual behavior

InternalError: Check failed: (func.defined()) is false: Error: Cannot find PackedFunc fused_relax_nn_attention_cutlass_gv in either Relax VM kernel library, or in TVM runtime PackedFunc registry, or in global Relax functions of the VM executable

Steps to reproduce

The following script reproduces the issue:

```python
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def attention(q_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), k_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), v_1: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16"), T_transpose: T.Buffer((T.int64(32), T.int64(8), T.int64(16), T.int64(8)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_transpose_1 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_transpose_2 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape_1 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_batch_matmul_NT = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_divide = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_softmax_maxelem = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
        T_softmax_exp = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_softmax_expsum = T.alloc_buffer((T.int64(512), T.int64(8)), "float16")
        T_softmax_norm = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_transpose_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        T_reshape_2 = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_batch_matmul_NN = T.alloc_buffer((T.int64(512), T.int64(8), T.int64(8)), "float16")
        T_reshape_3 = T.alloc_buffer((T.int64(32), T.int64(16), T.int64(8), T.int64(8)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(q_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = q_1[v_ax0, v_ax2, v_ax1, v_ax3]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(k_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = k_1[v_ax0, v_ax2, v_ax1, v_ax3]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
                T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
            with T.block("T_batch_matmul_NT"):
                v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k])
                T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
                T.block_attr({"layout_free_placeholders": [T_reshape_1]})
                with T.init():
                    T_batch_matmul_NT[v_b, v_i, v_j] = T.float16(0)
                T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
                T_divide[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] / T.sqrt(T.float16(8))
        for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(T_divide[v_i0, v_i1, v_k])
                T.writes(T_softmax_maxelem[v_i0, v_i1])
                with T.init():
                    T_softmax_maxelem[v_i0, v_i1] = T.float16(-65504)
                T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], T_divide[v_i0, v_i1, v_k])
        for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_divide[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
                T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
                T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_divide[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
        for i0, i1, k in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_expsum"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(T_softmax_exp[v_i0, v_i1, v_k])
                T.writes(T_softmax_expsum[v_i0, v_i1])
                with T.init():
                    T_softmax_expsum[v_i0, v_i1] = T.float16(0)
                T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
        for i0, i1, i2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
                T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
                T.block_attr({"axis": 2})
                T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_transpose_2"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(v_1[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = v_1[v_ax0, v_ax2, v_ax1, v_ax3]
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(8), T.int64(8)):
            with T.block("T_reshape_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
                T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2])
                T_reshape_2[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(512) // T.int64(16), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(16), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
        for b, i, j, k in T.grid(T.int64(512), T.int64(8), T.int64(8), T.int64(8)):
            with T.block("T_batch_matmul_NN"):
                v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
                T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_2[v_b, v_k, v_j])
                T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
                T.block_attr({"layout_free_placeholders": [T_reshape_2]})
                with T.init():
                    T_batch_matmul_NN[v_b, v_i, v_j] = T.float16(0)
                T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_2[v_b, v_k, v_j]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(16), T.int64(8), T.int64(8)):
            with T.block("T_reshape_3"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)])
                T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(16) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(8) + v_ax1) % T.int64(512), (v_ax3 // T.int64(8) + v_ax2) % T.int64(8), v_ax3 % T.int64(8)]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(32), T.int64(8), T.int64(16), T.int64(8)):
            with T.block("T_transpose_3"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_3[v_ax0, v_ax2, v_ax1, v_ax3]

    @R.function
    def entry_b(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
        cls = Module
        with R.dataflow():
            lv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass(q, k, v)
            R.output(lv)
        return lv

    @R.function
    def fused_relax_nn_attention_cutlass(q: R.Tensor((32, 8, 16, 8), dtype="float16"), k: R.Tensor((32, 8, 16, 8), dtype="float16"), v: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
        R.func_attr({"Codegen": "cutlass", "WorkspaceSize": 65536})
        cls = Module

        @R.function
        def gv(q_1: R.Tensor((32, 8, 16, 8), dtype="float16"), k_1: R.Tensor((32, 8, 16, 8), dtype="float16"), v_1: R.Tensor((32, 8, 16, 8), dtype="float16")) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
            R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536})
            with R.dataflow():
                gv_2 = R.call_tir(cls.attention, (q_1, k_1, v_1), out_sinfo=R.Tensor((32, 8, 16, 8), dtype="float16"))
                R.output(gv_2)
            return gv_2

        gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v)
        return gv1


mod = Module

# crash
mod = tvm.transform.Sequential([relax.transform.FuseTIR(), relax.transform.LambdaLift(), relax.transform.AllocateWorkspace()])(mod)

# pass
#mod = tvm.transform.Sequential([relax.transform.FuseTIR(), relax.transform.LambdaLift(), relax.transform.FuseTIR(), relax.transform.AllocateWorkspace()])(mod)

with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(mod, target='llvm')
    vm = relax.VirtualMachine(ex, tvm.cpu())
```

Any guidance on whether this is a bug or a known order dependency would be greatly appreciated. @Lunderberg

Lunderberg commented 1 month ago

This is a bit of a bug and a bit of an ordering dependency.

  1. The LambdaLift pass extracts local lambda functions into the module. However, FuseOps and FuseOpsByPattern use local lambda functions to represent functions that will be replaced with specific kernel invocations.
  2. The AllocateWorkspace pass adds a new workspace parameter to all top-level functions that have a "WorkspaceSize" attribute, and updates all other functions to provide the new workspace. However, if there is a call from one function with a "WorkspaceSize" attribute to another such function, the caller is left with a dangling GlobalVar.
  3. The FuseTIR pass removes the Relax function, replacing it with a PrimFunc, which avoids the issue altogether. It only inspects mod->functions, not local lambda functions, which is why it only had an effect after LambdaLift (see the sketch after this list).
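
One way to observe this interaction, sketched below under the assumption that the `Module` from the reproduction script is in scope, is to apply the three passes one at a time and inspect the module after each step, watching which global functions remain and what the remaining call sites refer to:

```python
import tvm
from tvm import relax

# Diagnostic sketch (assumes `Module` from the reproduction script above).
mod = Module
for name, transform in [
    ("FuseTIR", relax.transform.FuseTIR()),
    ("LambdaLift", relax.transform.LambdaLift()),
    ("AllocateWorkspace", relax.transform.AllocateWorkspace()),
]:
    mod = transform(mod)
    print(f"=== after {name} ===")
    # List the global functions currently present in the module ...
    print([gv.name_hint for gv in mod.get_global_vars()])
    # ... and dump the full module to check which GlobalVars are still called.
    mod.show()
```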

There are a couple of options for short-term fixes and a couple of options for long-term fixes.
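
For reference, the reproduction script above already demonstrates one short-term workaround: running FuseTIR() a second time, after LambdaLift() and before AllocateWorkspace(), so the lifted composite Relax function is replaced by a PrimFunc before the workspace rewrite. A minimal sketch, assuming the `Module` from the reproduction is in scope:

```python
import tvm
from tvm import relax

# Workaround pipeline from the reproduction above: the second FuseTIR() runs
# after LambdaLift() and before AllocateWorkspace(), removing the lifted
# composite Relax function so no dangling GlobalVar is left behind.
seq = tvm.transform.Sequential([
    relax.transform.FuseTIR(),
    relax.transform.LambdaLift(),
    relax.transform.FuseTIR(),
    relax.transform.AllocateWorkspace(),
])
mod = seq(Module)

with tvm.transform.PassContext(opt_level=4):
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
```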

Unfortunately, I don't have time to implement the medium/long term solutions at the moment, but could help guide somebody in their implementation if there's interest.

Thrsu commented 1 month ago

Thank you very much for your thorough analysis and explanation of the root cause of the bug, as well as the detailed guidance on how to address it. Unfortunately, I'm not very familiar with the Relax source code, so I might struggle to submit a PR to fix this myself. I do hope someone with the right expertise and interest can pick this up.

Thanks again for all your help, and I'm looking forward to seeing this issue tackled by the community!