Closed wrongtest-intellif closed 1 week ago
@tvm-bot re-run
If a loop's domain depends on other loops, `CreatePrimFunc` currently lacks the necessary transformations, which leads to undefined variables during lowering. See https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784
Hi, I posted this issue on the TVM community forum. Thanks for the quick reply and fix! I added a follow-up in the forum, but it is still awaiting approval before it becomes public, so I am forwarding it here:
When I tried to reproduce the error with your test code, I got the following error, which is similar to, but not exactly the same as, the one I got when compiling the model.
`Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/workspace/tvm/python/tvm/driver/build_module.py", line 297, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "/workspace/tvm/src/driver/driver_api.cc", line 532, in operator()
return TIRToRuntime(inputs_arg, host_target);
File "/workspace/tvm/src/driver/driver_api.cc", line 493, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
auto pair = SplitMixedModule(ir_module, target, target_host);
File "/workspace/tvm/src/driver/driver_api.cc", line 419, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
File "/workspace/tvm/src/driver/driver_api.cc", line 290, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
mod = seq(std::move(mod));
File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 435, in operator()
func = MakePackedAPI(std::move(func));
File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 398, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
tvm.error.InternalError: Traceback (most recent call last):
5: operator()
at /workspace/tvm/src/driver/driver_api.cc:532
4: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
at /workspace/tvm/src/driver/driver_api.cc:493
3: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
at /workspace/tvm/src/driver/driver_api.cc:419
2: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
at /workspace/tvm/src/driver/driver_api.cc:290
1: operator()
at /workspace/tvm/src/tir/transforms/make_packed_api.cc:435
0: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
at /workspace/tvm/src/tir/transforms/make_packed_api.cc:398
File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 398
InternalError: Check failed: undefined.size() == 0 (2 vs. 0) : In PrimFunc default_function variables [ax2, ax3] are used, but are not passed in as API arguments`
I also applied your fix to TVM, and when I recompiled the model, a new error appeared:
`build
relax.build(
File "/workspace/tvm/python/tvm/relax/vm_build.py", line 335, in build
mod = pipeline(mod)
File "/workspace/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
File "tvm/_ffi/_cython/core.cpp", line 7494, in __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback
TVMAPISetLastPythonError(((void *)__pyx_v_err));
File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
File "/mnt/volumes/jointmodel/songtianchen/mlc-llm-dev-eagle-lpai/python/mlc_llm/compiler_pass/pipeline.py", line 181, in _pipeline
mod = seq(mod)
File "/workspace/tvm/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
File "tvm/_ffi/_cython/core.cpp", line 7494, in __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback
TVMAPISetLastPythonError(((void *)__pyx_v_err));
File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
File "/workspace/tvm/python/tvm/ir/transform.py", line 307, in _pass_func
return inst.transform_module(mod, ctx)
File "/workspace/tvm/python/tvm/dlight/base/transform.py", line 71, in transform_module
sch = _apply_rules(func, target, self.rules, tunable=False)
File "/workspace/tvm/python/tvm/dlight/base/transform.py", line 87, in _apply_rules
space = rule.apply(func, target, tunable)
File "/workspace/tvm/python/tvm/dlight/gpu/general_reduction.py", line 114, in apply
sch.compute_at(block, bx, preserve_unit_loops=True)
File "/workspace/tvm/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
return func(*args, **kwargs)
File "/workspace/tvm/python/tvm/tir/schedule/schedule.py", line 2111, in compute_at
_ffi_api.ScheduleComputeAt( # type: ignore # pylint: disable=no-member
File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 277, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
at /workspace/tvm/src/tir/schedule/traced_schedule.cc:489
0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
at /workspace/tvm/src/tir/schedule/concrete_schedule.cc:790
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func(private=True)
def main(var_reshape240: T.handle, var_adaptive_pool_avg: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
reshape240 = T.match_buffer(var_reshape240, (T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16")
adaptive_pool_avg = T.match_buffer(var_adaptive_pool_avg, (T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
# tir.Block#0
with T.block("root"):
^^^^^^^^^^^^^^^^^^^^^
T.reads()
^^^^^^^^^
T.writes()
^^^^^^^^^^
adaptive_pool_sum_shared = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16", scope="shared")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0 in range(T.int64(2)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax1 in range(T.int64(1024)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax2 in range(T.int64(12)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax3 in range(T.int64(30)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax4 in range(T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ax2_1 = T.int64()
^^^^^^^^^^^^^^^^^
for ax5 in range(T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ax3_1 = T.int64()
^^^^^^^^^^^^^^^^^
with T.block("adaptive_pool_sum"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(2), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(1024), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v2 = T.axis.spatial(T.int64(12), ax2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v3 = T.axis.spatial(T.int64(30), ax3)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v4 = T.axis.reduce(T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), ax4)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v5 = T.axis.reduce(T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30), ax5)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(adaptive_pool_sum_shared[v0, v1, v2, v3])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.init():
^^^^^^^^^^^^^^
adaptive_pool_sum_shared[v0, v1, v2, v3] = T.float16(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
adaptive_pool_sum_shared[v0, v1, v2, v3] = adaptive_pool_sum_shared[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax0_ax1_ax2_ax3_fused in T.thread_binding(T.int64(737280), thread="blockIdx.x"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax4 in range(T.int64(1)):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax5_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
for ax5_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.block("adaptive_pool_avg"):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v0 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_fused // T.int64(368640))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v1 = T.axis.spatial(T.int64(1024), ax0_ax1_ax2_ax3_fused % T.int64(368640) // T.int64(360))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v2 = T.axis.spatial(T.int64(12), ax0_ax1_ax2_ax3_fused % T.int64(360) // T.int64(30))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v3 = T.axis.spatial(T.int64(30), ax0_ax1_ax2_ax3_fused % T.int64(30))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v4 = T.axis.spatial(T.int64(1), ax4)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v5 = T.axis.spatial(T.int64(1), ax5_0 * T.int64(256) + ax5_1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.where(ax5_0 * T.int64(256) + ax5_1 < T.int64(1))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(adaptive_pool_sum_shared[v0, v1, v2, v3])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(adaptive_pool_avg[v0, v1, v2, v3])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum_shared[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v3) * T.int64(40) // T.int64(30)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The scope tir.Block#0 is not a stage pipeline.
Definition of a scope that is a stage pipeline:
- The region cover property holds for every of its child blocks
- No write-after-read dependency or opaque dependency,
- only read-after-write and write-after-write are allowed
- All the statements in the scope are schedulable statements, i.e. Block and For`
I also printed out the generated TIR with `_DebugDump` in the compile pipeline and compared it with your test code; they are almost identical. The only difference is that my TIR contains some casts to int64, such as `T.Cast("int64", v_ax3)`:
`@T.prim_func(private=True)
def adaptive_avg_pool2d(reshape72: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
    # TIR for adaptive_avg_pool2d as dumped by _DebugDump during the mlc-llm build
    # (before dlight scheduling). Two stages: a sum reduction into
    # adaptive_pool_sum, then a division producing adaptive_pool_avg.
    T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
    # with T.block("root"):
    adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
        # The extents of rv0/rv1 are expressions of the outer loop vars ax2/ax3,
        # i.e. the reduction loop domain depends on other loops -- the case this
        # thread is about.
        for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)):
            with T.block("adaptive_pool_sum"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                T.reads(reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                with T.init():
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
        with T.block("adaptive_pool_avg"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
            T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
            # Divides by the same window extents used above. Unlike the upstream
            # test case, v_ax2/v_ax3 are wrapped in T.Cast("int64", ...) here.
            adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v_ax2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v_ax3) * T.int64(40) // T.int64(30)))`
In the new error message, the TIR printed to the terminal still uses the undefined variables `ax2_1` and `ax3_1`, and it differs from what I got from `_DebugDump` above.
Could you help look into this issue further? Many thanks!
When I tried to reproduce the error with your test code, I got the following error, which is similar to, but not exactly the same as, the one I got when compiling the model.
Hi, thanks for the follow-up information. Could you provide the Relax compile script? @SimonSongg
When I tried to reproduce the error with your test code, I got the following error, which is similar to, but not exactly the same as, the one I got when compiling the model.
Hi, thanks for the follow-up information. Could you provide the Relax compile script? @SimonSongg
Hi,
def _build_default():
    """Return the ``build`` callback that mlc-llm uses to compile a model.

    Quoted from mlc-llm in this bug report. ``IRModule``, ``CompileArgs``,
    ``logger``, ``_add_system_lib_prefix`` and ``relax`` are not defined in
    this excerpt.
    """
    def build(mod: IRModule, args: "CompileArgs", pipeline=None):
        """Compile ``mod`` with ``relax.build`` and export it to ``args.output``.

        The output suffix selects the library kind: ``.tar``/``.lib`` build a
        system library; ``.so``/``.dylib``/``.dll`` build a shared library;
        any other suffix logs a warning and falls back to a shared library.
        """
        output = args.output
        if output.suffix in [".tar", ".lib"]:
            system_lib = True
        elif output.suffix in [".so", ".dylib", ".dll"]:
            system_lib = False
        else:
            logger.warning("Unknown output suffix: %s. Assuming shared library.", output.suffix)
            system_lib = False
        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=system_lib)
        # `pipeline` is forwarded unchanged to relax.build; per the reporter it is
        # the custom `_pipeline` pass quoted later in this thread.
        relax.build(
            mod,
            target=args.target,
            pipeline=pipeline,
            system_lib=system_lib,
        ).export_library(
            str(output),
        )
    return build
Here is the code mlc-llm uses to compile the model.
The pipeline is:
@tvm.transform.module_pass(opt_level=0)
def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
    """mlc-llm's Relax-to-VM compilation pipeline, quoted from the bug report.

    Runs the passes below, in order, as one ``tvm.transform.Sequential``.
    Free names such as ``target``, ``flashinfer``, ``metadata``,
    ``variable_bounds``, ``debug_dump`` and ``ext_mods`` come from an
    enclosing scope that is not shown in this excerpt.
    """
    seq = tvm.transform.Sequential(
        [
            # Phase 0. Add additional information for compilation and remove unused Relax func
            DispatchKVCacheCreation(target, flashinfer, metadata),
            # AttachSoftmaxWithTemperature(target),
            AttachVariableBounds(variable_bounds),
            AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints),
            AttachLogitProcessFunc(target),
            AttachAdditionalPrimFuncs(additional_tirs),
            AttachAllocEmbeddingTensorFunc(metadata),
            # AttachGPUSamplingFunc(target, variable_bounds),
            AttachSpecDecodeAuxFuncs(tensor_parallel_shards),
            AttachMemoryPlanAttr(),
            tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)),
            _DebugDump("debug-phase0.py", debug_dump, show_meta=False),
            # Phase 1. Passes on high-level operator graph
            _LogProgress("Running TVM Relax graph-level optimizations"),
            FuseFTDequantizeEpilogue(),
            FuseDequantizeTranspose(),
            CublasDispatch() if cublas_gemm else tvm.transform.Sequential([]),
            FuseAddRMSNorm(target=target),
            FuseTransposeMatmul(),
            _DebugDump("debug-phase1.py", debug_dump, show_meta=False),
            # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
            _LogProgress("Lowering to TVM TIR kernels"),
            tvm.relax.backend.DispatchSortScan(),
            # LegalizeOps is where adaptive_avg_pool2d becomes the TIR PrimFunc
            # discussed in this thread (the one shown in the _DebugDump output).
            tvm.relax.transform.LegalizeOps(),
            tvm.relax.transform.AnnotateTIROpPattern(),
            tvm.relax.transform.FoldConstant(),
            tvm.relax.transform.FuseOps(),
            tvm.relax.transform.FuseTIR(),
            # _DebugDump("debug-phase2.py", debug_dump, show_meta=False),
            # Phase 3. Passes on TIR
            _LogProgress("Running TVM TIR-level optimizations"),
            FuseDequantizeMatmulEwise(),
            FuseDequantizeTake(),
            tvm.relax.transform.DeadCodeElimination(),
            CleanUpTIRAttrs(["op_pattern"]),
            # _DebugDump("debug-phase3.py", debug_dump, show_meta=False),
            # Phase 4. Low-level Optimizations
            _LogProgress("Running TVM Dlight low-level optimizations"),
            LowBatchGemvSpecialize(),
            # The reported ScheduleError is raised inside this step: the traceback
            # goes through dlight/base/transform.py into
            # dlight/gpu/general_reduction.py (sch.compute_at).
            dl.ApplyDefaultSchedule(
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
            _DebugDump("debug-phase4.py", debug_dump, show_meta=False),
            _LogProgress("Lowering to VM bytecode"),
            LiftTIRGlobalBufferAlloc(),
            (
                tvm.tir.transform.ForceNarrowIndexToInt32()
                if target.kind.name != "cuda"
                else tvm.transform.Sequential([])
            ),
            ScatterTupleGetItem(),
            tvm.relax.transform.RewriteDataflowReshape(),
            tvm.relax.transform.ToNonDataflow(),
            tvm.relax.transform.RemovePurityChecking(),
            tvm.relax.transform.CallTIRRewrite(),
            (
                tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy)
                if allreduce_strategy != IPCAllReduceStrategyType.NONE
                else tvm.transform.Sequential([])
            ),
            tvm.relax.transform.StaticPlanBlockMemory(),
            AttachMetadataWithMemoryUsage(metadata),
            tvm.relax.transform.RewriteCUDAGraph(),
            tvm.relax.transform.LowerGPUIPCAllocStorage(),
            tvm.relax.transform.LowerAllocTensor(),
            tvm.relax.transform.KillAfterLastUse(),
            tvm.relax.transform.VMBuiltinLower(),
            tvm.relax.transform.VMShapeLower(),
            tvm.relax.transform.AttachGlobalSymbol(),
            _DebugDump("debug-final.py", debug_dump, show_meta=False),
            _LogProgress("Compiling external modules"),
            tvm.relax.transform.AttachExternModules(ext_mods),
            _LogProgress("Compilation complete! Exporting to disk"),
        ]
    )
    mod = seq(mod)
    return mod
It looks like the error occurs in `dl.gpu.GeneralReduction()`.
Thanks!
When I tried to reproduce the error with your test code, I got the following error, which is similar to, but not exactly the same as, the one I got when compiling the model.
Hi, thanks for the follow-up information. Could you provide the Relax compile script? @SimonSongg
Hi @wrongtest-intellif, sorry to bother you, but I have a new finding.
In `tvm/dlight/gpu/general_reduction.py`, I printed out the IR module as follows:
# Debug snippet added inside tvm/dlight/gpu/general_reduction.py: dump the
# schedule's IR before and after normalize_prim_func to see what it changes.
sch = tir.Schedule(func)
print("===========================")
sch.show()
# normalize_prim_func is called on the schedule; comparing the two dumps
# shows how it rewrites the PrimFunc.
block_infos = normalize_prim_func(sch)
print("===========================")
sch.show()
I found that `normalize_prim_func()` changes the PrimFunc in a way that brings back the problem you already fixed in https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784!
Before `normalize_prim_func`:
@I.ir_module
class Module:
    # IR printed by sch.show() BEFORE normalize_prim_func (dlight GeneralReduction).
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            # Here the rv0/rv1 extents correctly refer to the enclosing loop vars
            # ax2/ax3, so every variable is bound.
            for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)):
                with T.block("adaptive_pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                    T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v_ax2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v_ax3 * T.int64(40) // T.int64(30)))
After `normalize_prim_func`:
@I.ir_module
class Module:
    # IR printed by sch.show() AFTER normalize_prim_func (dlight GeneralReduction).
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        # The six loops are now a single T.grid. The ax4/ax5 extents refer to
        # ax2_1/ax3_1 instead of the loop vars ax2/ax3. Those names are never bound
        # by any loop; they appear only as the bare T.int64() declarations below.
        # They are therefore undefined, and the extents' dependency on the outer
        # loops has been lost.
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
            ax2_1 = T.int64() ######## HERE ########
            ax3_1 = T.int64() ######## HERE ########
            with T.block("adaptive_pool_sum"):
                v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5])
                T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5])
                T.writes(adaptive_pool_sum[v0, v1, v2, v3])
                with T.init():
                    adaptive_pool_sum[v0, v1, v2, v3] = T.float16(0)
                adaptive_pool_sum[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v0, v1, v2, v3])
                T.writes(adaptive_pool_avg[v0, v1, v2, v3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v2 * T.int64(16) + T.int64(16)) // T.int64(12), (v2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v3 * T.int64(40) + T.int64(40)) // T.int64(30), (v3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v3 * T.int64(40) // T.int64(30)))
Could you give me some clues about how to fix this bug? Thank you very much!
When I tried to reproduce the error with your test code, I got the following error, which is similar to, but not exactly the same as, the one I got when compiling the model.
Hi, thanks for the follow-up information. Could you provide the Relax compile script? @SimonSongg
Hi @wrongtest-intellif, sorry to bother you, but I have a new finding.
In
tvm/dlight/gpu/general_reduction.py
, I printed out the ir module by:sch = tir.Schedule(func) print("===========================") sch.show() block_infos = normalize_prim_func(sch) print("===========================") sch.show()
And I found
normalize_prim_func()
change the prim_func, which made the problem https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784 solved by you occurred again!!Before
normalize_prim_func
:@I.ir_module class Module: @T.prim_func(private=True) def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16") for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)): for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)): with T.block("adaptive_pool_sum"): v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) T.reads(reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]) T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) with T.init(): adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0) adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)): with T.block("adaptive_pool_avg"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v_ax2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v_ax3 * T.int64(40) // T.int64(30)))
After
normalize_prim_func
:@I.ir_module class Module: @T.prim_func(private=True) def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16") for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)): ax2_1 = T.int64() ######## HERE ######## ax3_1 = T.int64() ######## HERE ######## with T.block("adaptive_pool_sum"): v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]) T.writes(adaptive_pool_sum[v0, v1, v2, v3]) with T.init(): adaptive_pool_sum[v0, v1, v2, v3] = T.float16(0) adaptive_pool_sum[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)): with T.block("adaptive_pool_avg"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(adaptive_pool_sum[v0, v1, v2, v3]) T.writes(adaptive_pool_avg[v0, v1, v2, v3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + 
T.int64(4)) % T.int64(12) == T.int64(0), (v2 * T.int64(16) + T.int64(16)) // T.int64(12), (v2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v3 * T.int64(40) + T.int64(40)) // T.int64(30), (v3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v3 * T.int64(40) // T.int64(30)))
Could you give some clues about how to fix this bug? Thank you very much!
This is because, besides PrimFunc creation, the schedule system and the auto-scheduling rules may also fail to correctly take loop-carried dependencies into account, since they were generally developed to optimize static-shape workloads.
We are trying to find a proper solution for such workloads.
When I tried to reproduce the error with your test code, I hit the following error, which is somewhat similar, but not identical, to the one I encountered when compiling the model.
Hi @SimonSongg, thanks for the follow-up information. Could you provide the Relax compile script?
Hi @wrongtest-intellif, sorry to bother you, but I have a new finding. In
tvm/dlight/gpu/general_reduction.py
, I printed out the ir module by:sch = tir.Schedule(func) print("===========================") sch.show() block_infos = normalize_prim_func(sch) print("===========================") sch.show()
And I found that `normalize_prim_func()` changes the PrimFunc in a way that makes the problem you solved (https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784) occur again! Before `normalize_prim_func`
:@I.ir_module class Module: @T.prim_func(private=True) def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16") for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)): for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)): with T.block("adaptive_pool_sum"): v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]) T.reads(reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]) T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) with T.init(): adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0) adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)): with T.block("adaptive_pool_avg"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v_ax2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v_ax3 * T.int64(40) // T.int64(30)))
After `normalize_prim_func`
:@I.ir_module class Module: @T.prim_func(private=True) def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16") for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)): ax2_1 = T.int64() ######## HERE ######## ax3_1 = T.int64() ######## HERE ######## with T.block("adaptive_pool_sum"): v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]) T.writes(adaptive_pool_sum[v0, v1, v2, v3]) with T.init(): adaptive_pool_sum[v0, v1, v2, v3] = T.float16(0) adaptive_pool_sum[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5] for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)): with T.block("adaptive_pool_avg"): v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(adaptive_pool_sum[v0, v1, v2, v3]) T.writes(adaptive_pool_avg[v0, v1, v2, v3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + 
T.int64(4)) % T.int64(12) == T.int64(0), (v2 * T.int64(16) + T.int64(16)) // T.int64(12), (v2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v3 * T.int64(40) + T.int64(40)) // T.int64(30), (v3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v3 * T.int64(40) // T.int64(30)))
Could you give some clues about how to fix this bug? Thank you very much!
This is because, besides PrimFunc creation, the schedule system and the auto-scheduling rules may also fail to correctly take loop-carried dependencies into account, since they were generally developed to optimize static-shape workloads.
We are trying to find a proper solution for such workloads.
Thank you very much for your reply!
cc @Hzfengsy, could you help review this PR?
Here are some updates on the new change. Because TE can define an axis whose domain depends on previous axes, it is problematic to convert such a compute op into one single block: block iter vars should not depend on the relative positions in which they are defined.
It seems better to represent such workloads with nested block levels, ensuring that each level's block iter vars are independent. For the adaptive pooling case discussed here, the changed `create_primfunc`
would generate the following:
This decomposition ensures the independence of the block vars.
@T.prim_func
def tir_workload(x: T.Buffer((1, 1024, 16, 40), "float32"), adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32")):
T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
with T.block("adaptive_pool_sum_1"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)])
T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30):
with T.block("adaptive_pool_sum"):
v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2)
v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3)
v_rv0, v_rv1 = T.axis.remap("RR", [rv0, rv1])
T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1])
T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1])
with T.init():
adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0)
adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]
for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
with T.block("adaptive_pool_avg"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30))
I also tried some tuning methods:
- `general_reduction`: failed. I think this is because the outer-level block is no longer a reduction block.
- `DisallowDynamicLoop`, plus forcibly skipping the inner block during space generation via `f_block_filter`: this produces correct and optimized results. An example trace:
def apply_trace(sch: tir.Schedule) -> None:
b0 = sch.get_block(name="adaptive_pool_sum_l1", func_name="main")
b1 = sch.get_block(name="adaptive_pool_avg", func_name="main")
b2 = sch.get_block(name="root", func_name="main")
sch.unannotate(block_or_loop=b1, ann_key="schedule_rule")
v3 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4)
sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v3)
l4, l5, l6, l7 = sch.get_loops(block=b1)
l8 = sch.fuse(l4, l5, l6, l7, preserve_unit_iters=True)
l9, l10, l11 = sch.split(loop=l8, factors=[None, 256, 1024], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l10, l11, l9)
sch.bind(loop=l10, thread_axis="blockIdx.x")
sch.bind(loop=l11, thread_axis="threadIdx.x")
l12, l13, l14, l15 = sch.get_loops(block=b0)
l16 = sch.fuse(l12, l13, l14, l15, preserve_unit_iters=True)
l17, l18, l19 = sch.split(loop=l16, factors=[None, 256, 1024], preserve_unit_iters=True, disable_predication=False)
sch.reorder(l18, l19, l17)
sch.bind(loop=l18, thread_axis="blockIdx.x")
sch.bind(loop=l19, thread_axis="threadIdx.x")
sch.enter_postproc()
b20 = sch.get_block(name="root", func_name="main")
sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.unroll_explicit")
b21, b22 = sch.get_child_blocks(b20)
l23, l24, l25 = sch.get_loops(block=b21)
l26, l27, l28 = sch.get_loops(block=b22)
b29 = sch.get_block(name="adaptive_pool_sum", func_name="main")
l30, l31 = sch.get_loops(block=b29)
b32 = sch.decompose_reduction(block=b29, loop=l30)
If a loop's domain depends on other loops, `CreatePrimFunc` currently lacks the necessary transformations, which leads to undefined variables during lowering. See https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784