apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug][MetaSchedule] Failed to tune fp16 dense_add workload of some shapes on cuda #14137

Closed wllqwzx closed 1 year ago

wllqwzx commented 1 year ago

I found that when tuning the fp16 tensorcore dense_add kernel, the tuning fails on some shapes and the reported error is non-deterministic.

For example, when the workload is N=1, M=1000, K=512, the tuning fails.

There are two kinds of reported errors. From my observation, the following one is reported more frequently:

Click me ``` 2023-02-27 14:11:46 [INFO] Logging directory: /tmp/tmp71o3_ldv/logs 2023-02-27 14:11:46 [INFO] LocalBuilder: max_workers = 11 2023-02-27 14:11:47 [INFO] LocalRunner: max_workers = 1 2023-02-27 14:11:48 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main" 2023-02-27 14:11:48 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main" Traceback (most recent call last): File "bug_tune_dense_add.py", line 507, in test_tune_tir_matmul_cuda_tensor_core() File "bug_tune_dense_add.py", line 195, in test_tune_tir_matmul_cuda_tensor_core database = tune_tir( File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", line 104, in tune_tir return tune_tasks( File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", line 117, in tune_tasks task_scheduler.tune( File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune _ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member File "/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__ raise get_last_ffi_error() tvm._ffi.base.TVMError: Traceback (most recent call last): 6: TVMFuncCall 5: _ZN3tvm7runtime13PackedF 4: tvm::runtime::TypedPackedFunc, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional)>::AssignTypedLambda, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional)#1}>(tvm::runtime::Registry::set_body_method, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional)#1}, std::__cxx11::basic_string, std::allocator >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const 3: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional) 2: tvm::meta_schedule::ReplayTraceNode::GenerateMeasureCandidates() 1: tvm::meta_schedule::ReplayTraceNode::State::GenerateMeasureCandidates() 0: tvm::support::parallel_for_dynamic(int, int, int, std::function const&) File "/mnt/disk5/wll/code/metaschedule/src/support/parallel_for.cc", line 128 RuntimeError: parallel_for_dynamic error with ScheduleError: (not 
rendered) ```

and the following error is reported less frequently:

Click me ``` 2023-02-27 14:20:13 [INFO] Logging directory: /tmp/tmputfxvrl5/logs 2023-02-27 14:20:13 [INFO] LocalBuilder: max_workers = 11 2023-02-27 14:20:14 [INFO] LocalRunner: max_workers = 1 2023-02-27 14:20:15 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main" Traceback (most recent call last): File "bug_tune_dense_add.py", line 507, in test_tune_tir_matmul_cuda_tensor_core() File "bug_tune_dense_add.py", line 195, in test_tune_tir_matmul_cuda_tensor_core database = tune_tir( File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tir_integration.py", line 104, in tune_tir return tune_tasks( File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/tune.py", line 117, in tune_tasks task_scheduler.tune( File "/mnt/disk5/wll/code/metaschedule/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune _ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member File "/mnt/disk5/wll/code/metaschedule/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__ raise get_last_ffi_error() tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last): 9: TVMFuncCall 8: _ZN3tvm7runtime13PackedF 7: tvm::runtime::TypedPackedFunc, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional)>::AssignTypedLambda, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional)#1}>(tvm::runtime::Registry::set_body_method, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional)#1}, std::__cxx11::basic_string, std::allocator >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const 6: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array, tvm::runtime::Array, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array, tvm::runtime::Optional, tvm::runtime::Optional) 5: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&) 4: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&) 3: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector >) 2: tvm::meta_schedule::MultiLevelTilingNode::AddReadReuse(tvm::meta_schedule::State) const 1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV 
const&, tvm::tir::LoopRV const&, bool, int) 0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int) ScheduleError: An error occurred in the schedule primitive 'compute-at'. The IR with diagnostic is: # from tvm.script import ir as I # from tvm.script import tir as T @I.ir_module class Module: @T.prim_func def main(p0_handle: T.handle, p1_handle: T.handle, p2_handle: T.handle, T_add_handle: T.handle): T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True}) p0 = T.match_buffer(p0_handle, (T.int64(8), T.int64(512)), "float16") p1 = T.match_buffer(p1_handle, (T.int64(1000), T.int64(512)), "float16") p2 = T.match_buffer(p2_handle, (T.int64(8), T.int64(1000)), "float16") T_add = T.match_buffer(T_add_handle, (T.int64(8), T.int64(1000)), "float16") # tir.Block#0 with T.block("root"): ^^^^^^^^^^^^^^^^^^^^^ T.reads() ^^^^^^^^^ T.writes() ^^^^^^^^^^ T_matmul_NT = T.alloc_buffer((T.int64(8), T.int64(1000)), "float16") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ p0_reindex = T.alloc_buffer((T.int64(16), T.int64(512)), "float16") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ p1_reindex = T.alloc_buffer((T.int64(1008), T.int64(512)), "float16") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T_matmul_NT_reindex_shared_dyn = T.alloc_buffer((T.int64(16), T.int64(1008)), "float16", scope="shared.dyn") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T_matmul_NT_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((T.int64(16), T.int64(1008)), "float16", scope="wmma.accumulator") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ p0_reindex_shared_dyn = T.alloc_buffer((T.int64(16), T.int64(512)), "float16", scope="shared.dyn") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1 in range(T.int64(512)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("p0_reindex_reindex"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0 = T.axis.spatial(T.int64(16), ax0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1 = T.axis.spatial(T.int64(512), ax1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(p0[v0, v1]) ^^^^^^^^^^^^^^^^^^^ T.writes(p0_reindex[v0, v1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ p0_reindex[v0, v1] = T.if_then_else(v0 < T.int64(8), p0[v0, v1], T.float16(0)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0 in range(T.int64(1008)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1 in range(T.int64(512)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("p1_reindex_reindex"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0 = T.axis.spatial(T.int64(1008), ax0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1 = T.axis.spatial(T.int64(512), ax1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(p1[v0, v1]) ^^^^^^^^^^^^^^^^^^^ T.writes(p1_reindex[v0, v1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ p1_reindex[v0, v1] = T.if_then_else(v0 < T.int64(1000), p1[v0, v1], T.float16(0)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1 in range(T.int64(512)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("p0_reindex_shared.dyn"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0 = T.axis.spatial(T.int64(16), ax0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1 = 
T.axis.spatial(T.int64(512), ax1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(p0_reindex[v0, v1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.writes(p0_reindex_shared_dyn[v0, v1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ p0_reindex_shared_dyn[v0, v1] = p0_reindex[v0, v1] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_0_0_ax1_0_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.y"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_0_1_ax1_0_1_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_0_2_ax1_0_2_fused in T.thread_binding(T.int64(3), thread="threadIdx.y"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax2_0_0 in range(T.int64(1)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax2_0_1 in range(T.int64(32)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_0_3 in range(T.int64(1)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1_0_3 in range(T.int64(21)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax2_0_2 in range(T.int64(1)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_0_4 in range(T.int64(1)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1_0_4 in range(T.int64(1)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("T_matmul_NT_o"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0_o = T.axis.spatial(T.int64(1), ax0_0_4 + ax0_0_3) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1_o = T.axis.spatial(T.int64(63), ax1_0_4 + ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0_3) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v2_o = T.axis.reduce(T.int64(32), ax2_0_0 * T.int64(32) + ax2_0_1 + ax2_0_2) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(p0_reindex_shared_dyn[T.int64(0):T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], p1_reindex[v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f16_trans", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f16", "meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "warp_execution": 1}) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.init(): ^^^^^^^^^^^^^^ for ax0_1 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1_1 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("T_matmul_NT_init"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0_i_init = T.axis.spatial(T.int64(16), ax0_1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1_i_init = T.axis.spatial(T.int64(16), ax1_1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads() ^^^^^^^^^ T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * 
T.int64(16) + v1_i_init]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i_init, v1_o * T.int64(16) + v1_i_init] = T.float16(0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_1 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1_1 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax2_1 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("T_matmul_NT"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0_i = T.axis.spatial(T.int64(16), ax0_1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1_i = T.axis.spatial(T.int64(16), ax1_1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v2_i = T.axis.reduce(T.int64(16), ax2_1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i], p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i], p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.writes(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i] + p0_reindex_shared_dyn[v0_i, v2_o * T.int64(16) + v2_i] * p1_reindex[v1_o * T.int64(16) + v1_i, v2_o * T.int64(16) + v2_i] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_0 in range(T.int64(1)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1_0 in range(T.int64(21)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator_o"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0_o = T.axis.spatial(T.int64(1), ax0_0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1_o = T.axis.spatial(T.int64(63), ax0_0_2_ax1_0_2_fused * T.int64(21) + ax1_0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.writes(T_matmul_NT_reindex_shared_dyn[T.int64(0):T.int64(16), v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16)]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f16_shared_dyn"}) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_1 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1_1 in range(T.int64(16)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("T_matmul_NT_reindex_shared.dyn_wmma.accumulator"): 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0_i = T.axis.spatial(T.int64(16), ax0_1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1_i = T.axis.spatial(T.int64(16), ax1_1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.writes(T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T_matmul_NT_reindex_shared_dyn[v0_i, v1_o * T.int64(16) + v1_i] = T_matmul_NT_reindex_shared_dyn_wmma_accumulator[v0_i, v1_o * T.int64(16) + v1_i] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0_ax1_fused in range(T.int64(16128)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("T_matmul_NT_reindex_shared.dyn"): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v0 = T.axis.spatial(T.int64(16), ax0_ax1_fused // T.int64(1008)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v1 = T.axis.spatial(T.int64(1008), ax0_ax1_fused % T.int64(1008)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.where(ax0_ax1_fused < T.int64(8056) and ax0_ax1_fused % T.int64(1008) < T.int64(1000)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(T_matmul_NT_reindex_shared_dyn[v0, v1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.writes(T_matmul_NT[v0, v1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.block_attr({"meta_schedule.cooperative_fetch": 1}) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T_matmul_NT[v0, v1] = T_matmul_NT_reindex_shared_dyn[v0, v1] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax0 in range(T.int64(8)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ for ax1 in range(T.int64(1000)): ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ with T.block("T_add"): ^^^^^^^^^^^^^^^^^^^^^^ v_ax0 = T.axis.spatial(T.int64(8), ax0) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ v_ax1 = T.axis.spatial(T.int64(1000), ax1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T.writes(T_add[v_ax0, v_ax1]) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Error message: The scope tir.Block#0 is not a stage pipeline. Definition of a scope that is a stage pipeline: - The region cover property holds for every of its child blocks - No write-after-read dependency or opaque dependency, - only read-after-write and write-after-write are allowed - All the statements in the scope are schedulable statements, i.e. Block and For ```

I tried different values of N and found that the tuning still fails for N=2, 4, 8, 12, 17, 18, 24, but succeeds for N=16, 32. I guess this may be because of the alignment requirement of the m16n16k16 Tensor Core intrinsic.
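
This matches the padded buffer shapes in the error dump (p0_reindex is (16, 512) and p1_reindex is (1008, 512), i.e. rounded up to multiples of 16). Below is a trivial sketch of that alignment check (my own illustration, not part of the repro script):

```python
# My own illustration of the alignment guess above: the m16n16k16 WMMA intrinsic
# tiles M/N/K in multiples of 16, so shapes that are not multiples of 16 get
# padded (1 -> 16, 1000 -> 1008), matching the reindex buffers in the error dump.
def round_up_to_16(x: int) -> int:
    return (x + 15) // 16 * 16

for n in (1, 2, 4, 8, 12, 16, 17, 18, 24, 32):
    status = "aligned" if n % 16 == 0 else "padded"
    print(f"N={n:>2} -> {round_up_to_16(n):>2} ({status})")
```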

Expected behavior

The tuning succeeds

Environment

Steps to reproduce

import tempfile

import tvm
import tvm.tir.tensor_intrin  # registers the tensor intrinsics (including the CUDA WMMA ones) used for auto-tensorization
from tvm import meta_schedule as ms
from tvm.meta_schedule import tune_tir
from tvm.target import Target
from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module0:
    @T.prim_func
    def main(p0: T.Buffer((T.int64(1), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(1), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(1000)), "float16")):
        T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
        # with T.block("root"):
        T_matmul_NT = T.alloc_buffer((T.int64(1), T.int64(1000)), "float16")
        for i, j, k in T.grid(T.int64(1), T.int64(1000), T.int64(512)):
            with T.block("T_matmul_NT"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(p0[v_i, v_k], p1[v_j, v_k])
                T.writes(T_matmul_NT[v_i, v_j])
                with T.init():
                    T_matmul_NT[v_i, v_j] = T.float16(0)
                T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]

@I.ir_module
class Module1:
    @T.prim_func
    def main(p0: T.Buffer((T.int64(16), T.int64(512)), "float16"), p1: T.Buffer((T.int64(1000), T.int64(512)), "float16"), p2: T.Buffer((T.int64(16), T.int64(1000)), "float16"), T_add: T.Buffer((T.int64(16), T.int64(1000)), "float16")):
        T.func_attr({"global_symbol": "main", "layout_free_buffers": [1], "tir.noalias": True})
        # with T.block("root"):
        T_matmul_NT = T.alloc_buffer((T.int64(16), T.int64(1000)), "float16")
        for i, j, k in T.grid(T.int64(16), T.int64(1000), T.int64(512)):
            with T.block("T_matmul_NT"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(p0[v_i, v_k], p1[v_j, v_k])
                T.writes(T_matmul_NT[v_i, v_j])
                with T.init():
                    T_matmul_NT[v_i, v_j] = T.float16(0)
                T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + p0[v_i, v_k] * p1[v_j, v_k]
        for ax0, ax1 in T.grid(T.int64(16), T.int64(1000)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_matmul_NT[v_ax0, v_ax1], p2[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = T_matmul_NT[v_ax0, v_ax1] + p2[v_ax0, v_ax1]

def tune_dense_add_cuda_tensor_core():
    target = Target("nvidia/nvidia-a100")
    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.database.Database.create(kind="json", work_dir=work_dir)
        mod = Module0   # failed
        # mod = Module1   # success
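        # N=1 (Module0) reproduces the failure; N=16 (Module1) tunes successfully (see the N sweep above).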
        database = tune_tir(
            mod=mod,
            target=target,
            work_dir=work_dir,
            num_trials_per_iter=10,
            max_trials_global=10,
            strategy="replay-trace",
            # strategy="evolutionary",
            database=database,
        )
        sch = ms.tir_integration.compile_tir(database, mod, target)
        if sch is None:
            print("No valid schedule found!")
        else:
            from tvm.contrib import nvcc
            ctx = tvm.cuda()
            if nvcc.have_tensorcore(ctx.compute_version):
                with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
                    func = tvm.build(sch.mod["main"], [], "cuda")
                    # print(func.imported_modules[0].get_source())
                    # print(sch.mod.script())
                    # print(sch.trace)

if __name__ == "__main__":
    tune_dense_add_cuda_tensor_core()

Triage

cc @ibsidorenko

masahi commented 1 year ago

Can you examine what the function https://github.com/apache/tvm/blob/7fd0cdb230ac58f2311b07a6fbea3ff7cb98aa07/python/tvm/topi/cuda/tensorcore_alter_op.py#L133 is returning for this input? For such a shape, I expect it to be padded so that it can be consumed by Tensor Cores.

wllqwzx commented 1 year ago

No, that function seems to be invoked in Relay's Legalize pass, whereas this input is a prim_func.
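
For reference, here is a minimal sketch of how the same fp16 dense + add workload would reach that legalization at the Relay level (my own illustration, not from the original report; it assumes relay.nn.dense and relay.transform.Legalize and is not a drop-in fix for the prim_func path):

```python
import tvm
from tvm import relay

# My own sketch: the N=1 dense + add expressed in Relay, so the CUDA dense
# legalization (which pads shapes for Tensor Cores) has a chance to run.
data = relay.var("data", shape=(1, 512), dtype="float16")
weight = relay.var("weight", shape=(1000, 512), dtype="float16")
bias = relay.var("bias", shape=(1, 1000), dtype="float16")
out = relay.add(relay.nn.dense(data, weight, out_dtype="float16"), bias)
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))
mod = relay.transform.InferType()(mod)

# Legalization dispatches on the current target, so run it under a CUDA target.
with tvm.target.Target("cuda"):
    with tvm.transform.PassContext(opt_level=3):
        legalized = relay.transform.Legalize()(mod)
print(legalized)
```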

Hzfengsy commented 1 year ago

cc @vinx13

vinx13 commented 1 year ago

I'll take a look. There was a previous attempt (#14030) to solve this; it alleviates most cases but is not a full solution. The problem is that the current arithmetic analysis cannot handle arbitrary block predicates; it can only handle simple bounds like min < var or var < max, where min/max are constants. We also updated the search space in #14108. I'll double-check whether this still occurs in the new search space.
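
To make the "simple bounds" distinction concrete, here is a small sketch (my own, not from the thread) using tvm.arith.Analyzer: a constant bound such as `i < 1000` is the kind of constraint the analysis exploits, whereas compound block predicates like the `fused < 8056 and fused % 1008 < 1000` guard in the dump above are the ones described here as out of reach:

```python
import tvm
from tvm import tir

# My own sketch: constraints of the form `var < const` are the "simple bounds"
# mentioned above, and the analyzer can reason from them directly.
ana = tvm.arith.Analyzer()
i = tir.Var("i", "int64")
with ana.constraint_scope(i < tvm.tir.const(1000, "int64")):
    # i < 1000 implies i < 1008, which follows from the simple constant bound
    print(ana.can_prove(i < tvm.tir.const(1008, "int64")))
```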

lileidev commented 1 year ago

I met the following error when trying to reproduce:

(base) root@356a70204ac9:/workspace/tvm/debug/issue14137# python main.py
error: module 'tvm.target._ffi_api' has no attribute 'llvm_lookup_intrinsic_id'
 --> /root/anaconda3/lib/python3.9/site-packages/tvm-0.12.dev387+gccc0b9162-py3.9-linux-x86_64.egg/tvm/tir/tensor_intrin/arm_cpu.py:65:13
    |
 65 | T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
note: run with TVM_BACKTRACE=1 environment variable to display a backtrace.

Environment: CPU: AMD EPYC 7543, GPU: NVIDIA A100

tvm commit id: ccc0b9162f2e983a8810e99c903c7141dbec81b6

Hzfengsy commented 1 year ago

@lileidev, it looks like you are trying to build an Arm module on an x86 CPU, which does not work.

lileidev commented 1 year ago

I compiled TVM with the default cmake/config.cmake file and did not specify an Arm platform. As you can see, the package path "tvm-0.12.dev387+gccc0b9162-py3.9-linux-x86_64.egg" is x86_64. This error can be reproduced just by "import tvm.tir.tensor_intrin".

Hzfengsy commented 1 year ago

error: module 'tvm.target._ffi_api' has no attribute 'llvm_lookup_intrinsic_id'

T should be imported via from tvm.script import tir as T, but it looks like it resolves to tvm.target for some reason. I have no specific idea why, but the branch works in my environment and in CI.
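
One way to narrow this down (my own suggestion, not from the thread): the failing call resolves an LLVM intrinsic id through tvm.target._ffi_api, and tvm.target.codegen.llvm_lookup_intrinsic_id wraps the same FFI symbol, so calling it in isolation shows whether that symbol is registered in this particular build at all:

```python
# My own diagnostic sketch: if this raises the same "no attribute
# 'llvm_lookup_intrinsic_id'" error, the problem is with the build/packaging of
# tvm.target._ffi_api rather than with this particular reproduction script.
from tvm.target import codegen

print(codegen.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"))
```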

lileidev commented 1 year ago

Both Module1 and Module0 run and pass on my machine.

Module1:

2023-03-14 10:05:44 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0
2023-03-14 10:05:44 [INFO] [task_scheduler.cc:320]
ID | Name | FLOP     | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
 0 | main | 16400000 |      1 |      2356.2847 |       6.9601 |                6.9601 |     10 | Y

Total trials: 10
Total latency (us): 6.96011

Module0:

2023-03-14 10:05:05 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0
2023-03-14 10:05:05 [INFO] [task_scheduler.cc:320]
ID | Name | FLOP    | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
 0 | main | 1025000 |      1 |       157.0099 |       6.5282 |                6.5282 |     10 | Y

Total trials: 10
Total latency (us): 6.52825