Open XFPlus opened 1 year ago
I understand where the bug is coming from and your workaround, but since we get two kernels anyway, why not create a prim func for each of them? It's odd to see T.launch_thread("threadIdx.x", ...) with different thread counts in a single prim func.
Hello @masahi, thanks for your reply. The bug appeared when I tried to use MetaSchedule to tune resnet-50: it picked a conv2d winograd implementation that generates a script like that. So I tried to reproduce the bug using this irregular code.
Here is my original test code:
# Imports inferred from the identifiers used below. get_network, model_path,
# dtype, logfile and tune come from the rest of meta-tuner.py (not shown).
import os

import tvm
from tvm import meta_schedule as ms
from tvm import relay
from tvm.contrib.cutlass import finalize_modules, finalize_modules_vm
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

target = tvm.target.cuda(arch="sm_80", options="-max_threads_per_block 1024 -max_shared_memory_per_block 49152")
# target = "llvm -num-cores 56"
mod, params, input_shape, out_shape = get_network(model_path, dtype)
seq = tvm.transform.Sequential(
    [
        relay.transform.ToMixedPrecision("float16"),
        relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "HWIO"]}),
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print('layout transform mod: ', mod)
mod = partition_for_cutlass(mod)
print('partition mod: ', mod)
# run tuning tasks
print("Tuning...")
# with tempfile.TemporaryDirectory() as work_dir:
backend = 'vm'
if True:
    work_dir = logfile
    os.makedirs(work_dir, exist_ok=True)
    database = ms.database.JSONDatabase(work_dir=work_dir)
    with ms.Profiler() as profiler:
        if tune:
            extracted_tasks = ms.relay_integration.extract_tasks(
                mod,
                target,
                params,
                opt_level=3,
                # disabled_pass=['AlterOpLayout']
            )
            for i in extracted_tasks:
                print(i.task_name)
                print(i.mod)
            tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts(
                extracted_tasks=extracted_tasks,
                work_dir=work_dir,
            )
            database = ms.relay_integration.tune_tasks(
                tasks=tasks,
                task_weights=task_weights,
                work_dir=work_dir,
                database=database,
                # config=ms.TuneConfig(strategy='evolutionary', max_trials_global=8 * len(extracted_tasks), max_trials_per_task=8, num_trials_per_iter=16,),
                max_trials_global=10000,
                # num_trials_per_iter=32,
                builder=ms.builder.LocalBuilder(timeout_sec=120),
                runner=ms.runner.RPCRunner(
                    rpc_config=ms.runner.config.RPCConfig(tracker_host='127.0.0.1', tracker_port=9190, tracker_key='a100', session_timeout_sec=30),
                    max_workers=8,
                ),
            )
        else:
            database = ms.database.JSONDatabase(work_dir=logfile)
        lib = ms.relay_integration.compile_relay(database, mod, target, params, backend=backend)
    print(profiler.table())
# compile kernels with history best records
print("Compile...")
# with database, tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_meta_schedule": True}):
#     lib = relay.build_module.build(mod, target=target, params=params)
# lib = relay.build(mod, target=target, params=params)
if backend == 'vm':
    lib = finalize_modules_vm(lib, "compile.so")
else:
    lib = finalize_modules(lib, "compile.so", './tmp')
And the full error message:
Traceback (most recent call last):
File "meta-tuner.py", line 231, in <module>
tune_and_evaluate(logfile, tune=True)
File "meta-tuner.py", line 155, in tune_and_evaluate
database = ms.relay_integration.tune_tasks(
File "/home/xfplus/ws/project/tvm/python/tvm/meta_schedule/tune.py", line 118, in tune_tasks
task_scheduler.tune(
File "/home/xfplus/ws/project/tvm/python/tvm/meta_schedule/task_scheduler/task_scheduler.py", line 132, in tune
_ffi_api.TaskSchedulerTune( # type: ignore # pylint: disable=no-member
File "/home/xfplus/ws/project/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
6: TVMFuncCall
5: _ZN3tvm7runtime13PackedFuncObj
4: tvm::runtime::TypedPackedFunc<void (tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::TaskScheduler, tvm::meta_schedule::TaskSchedulerNode, void, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>, void>(void (tvm::meta_schedule::TaskSchedulerNode::*)(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>))::{lambda(tvm::meta_schedule::TaskScheduler, tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
3: tvm::meta_schedule::GradientBasedNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
2: tvm::meta_schedule::TaskSchedulerNode::Tune(tvm::runtime::Array<tvm::meta_schedule::TuneContext, void>, tvm::runtime::Array<tvm::FloatImm, void>, int, int, int, tvm::meta_schedule::Builder, tvm::meta_schedule::Runner, tvm::runtime::Array<tvm::meta_schedule::MeasureCallback, void>, tvm::runtime::Optional<tvm::meta_schedule::Database>, tvm::runtime::Optional<tvm::meta_schedule::CostModel>)
1: tvm::meta_schedule::SendToBuilder(tvm::meta_schedule::TaskRecordNode*, tvm::meta_schedule::Builder const&)
0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
File "/home/xfplus/ws/project/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
rv = local_pyfunc(*pyargs)
File "/home/xfplus/ws/project/tvm/python/tvm/meta_schedule/utils.py", line 76, in method
return getattr(inst, name)(*args, **kwargs)
File "/home/xfplus/ws/project/tvm/python/tvm/meta_schedule/builder/local_builder.py", line 163, in build
results.append(BuilderResult(_worker_func(self.f_build, self.f_export, build_input.mod, build_input.target, None), None))
File "/home/xfplus/ws/project/tvm/python/tvm/meta_schedule/builder/local_builder.py", line 235, in _worker_func
rt_mod: Module = f_build(mod, target, _deserialize_params(params))
File "/home/xfplus/ws/project/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
1: TVMFuncCall
0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
File "/home/xfplus/ws/project/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
rv = local_pyfunc(*pyargs)
File "/home/xfplus/ws/project/tvm/python/tvm/meta_schedule/builder/local_builder.py", line 265, in default_build
return tvm_build(mod, target=target)
File "/home/xfplus/ws/project/tvm/python/tvm/driver/build_module.py", line 283, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
File "/home/xfplus/ws/project/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
207: TVMFuncCall
206: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
205: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
204: tvm::codegen::Build(tvm::IRModule, tvm::Target)
203: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::codegen::{lambda(tvm::IRModule, tvm::Target)#6}>(tvm::codegen::{lambda(tvm::IRModule, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
202: tvm::codegen::LLVMModuleNode::Init(tvm::IRModule const&, tvm::Target const&)
201: void tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > > >(__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, __gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >)::{lambda(auto:1)#1}>(__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, __gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, tvm::codegen::CodeGenLLVM::AddFunctionsOrdered<__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > > >(__gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >, __gnu_cxx::__normal_iterator<tvm::tir::PrimFunc*, std::vector<tvm::tir::PrimFunc, std::allocator<tvm::tir::PrimFunc> > >)::{lambda(auto:1)#1})
200: tvm::codegen::CodeGenCPU::AddFunction(tvm::tir::PrimFunc const&)
199: tvm::codegen::CodeGenLLVM::AddFunctionInternal(tvm::tir::PrimFunc const&, bool)
198: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
197: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
196: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
195: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
194: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
193: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
192: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
191: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
190: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
189: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
188: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
187: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
186: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
185: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
184: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
183: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
182: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
181: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
180: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
179: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
178: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
177: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
176: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
175: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
174: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
173: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
172: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
171: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
170: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
169: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
168: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
167: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
166: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
165: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
164: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
163: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
162: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
161: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
160: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
159: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
158: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
157: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
156: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
155: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
154: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
153: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
152: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
151: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
150: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
149: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
148: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
147: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
146: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
145: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
144: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
143: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
142: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AttrStmtNode const*)
141: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
140: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
139: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
138: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::LetStmtNode const*)
137: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
136: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
135: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
134: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
133: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
132: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
131: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
130: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
129: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
128: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
127: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
126: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
125: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
124: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
123: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
122: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
121: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
120: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
119: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
118: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
117: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
116: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
115: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
114: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
113: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
112: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
111: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
110: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
109: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
108: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
107: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
106: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
105: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
104: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
103: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
102: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
101: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
100: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
99: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
98: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
97: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
96: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
95: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
94: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
93: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
92: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
91: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
90: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
89: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
88: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
87: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
86: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
85: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
84: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
83: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
82: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
81: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
80: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
79: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
78: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
77: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
76: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
75: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
74: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
73: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
72: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
71: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
70: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
69: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
68: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
67: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
66: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
65: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
64: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
63: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
62: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
61: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
60: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
59: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
58: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
57: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
56: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
55: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
54: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
53: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
52: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
51: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
50: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
49: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
48: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
47: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
46: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
45: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
44: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
43: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
42: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
41: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
40: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
39: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
38: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
37: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
36: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
35: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
34: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
33: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
32: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
31: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
30: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
29: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
28: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
27: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
26: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
25: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
24: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
23: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
22: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
21: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
20: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
19: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
18: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
17: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
16: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
15: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
14: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
13: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
12: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
11: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
10: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
9: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
8: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
7: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
6: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AssertStmtNode const*)
5: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::AssertStmtNode const*)
4: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
3: tvm::codegen::CodeGenLLVM::VisitStmt_(tvm::tir::SeqStmtNode const*)
2: tvm::codegen::CodeGenCPU::VisitStmt_(tvm::tir::AttrStmtNode const*)
1: tvm::codegen::CodeGenCPU::CreateComputeScope(tvm::tir::AttrStmtNode const*)
0: tvm::codegen::CodeGenLLVM::GetVarValue(tvm::tir::VarNode const*) const
File "/home/xfplus/ws/project/tvm/src/target/llvm/codegen_llvm.cc", line 928
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (it != var_map_.end()) is false: cannot find variable buf_dyn_shmem
Also, I have dumped the IRModule and attached it as mod_mixed.txt.
ok thanks. Can you send a PR to fix this problem? And when you do so, please try to "inline" the logic in DynamicSharedMemoryRewriterWrapper that you added into DynamicSharedMemoryRewriter, so that we don't have to add the wrapper class.
OK. I'll try it.
When I tuned resnet-50 using MetaSchedule, I found that the conv2d_winograd implementations raise the error "cannot find variable buf_dyn_shmem". After digging in, I think it is caused by MergeDynamicSharedMemoryAllocations, which does not work correctly when it meets multiple blocks in one prim_func. When we pass in a multi-block prim_func like the code shown below, this pass generates the buf_dyn_shmem allocation in the first statement and references it in the later statements, which leads to the error.
Steps to reproduce
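This is not the original reproducer. As a rough, TVM-free illustration of the failure mode described above, the toy model below treats a prim_func as a list of kernels and compares merging once for the whole function against merging per kernel. The Kernel class, the kernel and buffer names, and all three helper functions are hypothetical and only mimic the idea; they are not TVM APIs.

```python
from dataclasses import dataclass, field


@dataclass
class Kernel:
    """Toy stand-in for one kernel (thread-launch region) inside a prim_func."""
    name: str
    dyn_shared_allocs: dict                      # buffer name -> size in bytes
    defines: set = field(default_factory=set)    # variables allocated in this kernel
    uses: set = field(default_factory=set)       # variables referenced in this kernel


def merge_once_for_whole_function(kernels):
    """Mimic the reported behavior: the merged buffer is allocated only in the
    first kernel, but every kernel references it."""
    merged = []
    for index, kernel in enumerate(kernels):
        new_kernel = Kernel(kernel.name, dict(kernel.dyn_shared_allocs),
                            uses={"buf_dyn_shmem"})
        if index == 0:
            new_kernel.defines.add("buf_dyn_shmem")
        merged.append(new_kernel)
    return merged


def merge_per_kernel(kernels):
    """Mimic the expected behavior: each kernel gets its own merged allocation."""
    return [
        Kernel(kernel.name, dict(kernel.dyn_shared_allocs),
               defines={"buf_dyn_shmem"}, uses={"buf_dyn_shmem"})
        for kernel in kernels
    ]


def undefined_uses(kernels):
    """Report, per kernel, the variables that are used but never defined there."""
    return {
        kernel.name: sorted(kernel.uses - kernel.defines)
        for kernel in kernels
        if kernel.uses - kernel.defines
    }


kernels = [
    Kernel("kernel_0", {"A_shared": 4096}),
    Kernel("kernel_1", {"B_shared": 2048}),
]
print(undefined_uses(merge_once_for_whole_function(kernels)))
print(undefined_uses(merge_per_kernel(kernels)))
```

In this toy model only the whole-function merge leaves kernel_1 with a reference to a buf_dyn_shmem it never defines, which is the shape of the "cannot find variable buf_dyn_shmem" failure reported here.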
Expected behavior
MergeDynamicSharedMemoryAllocations can process each statement independently like what I do now, or maybe SplitHostDevice can do more for this case?
generated irmodule:
Actual behavior
generated irmodule:
Environment
I'm using TVM v0.13.dev0 with commit: 4e07a8ed6687a08b6b27db21af019a5a179b9ee1 on a linux-x86_64 machine.
Something
And here is my workaround:
Triage