Closed: liqiangxl closed this pull request 1 month ago.
!build
Could you elaborate more on the problem? Why is only a pointer to bool a problem, and not pointers to other types?
Because `handle(const kir::Asm* asm_)` in codegen has special treatment for the bool type. It creates a predicate register for each bool operand, e.g.:
```cpp
if (val->dtype() == DataType::Bool) {
  indent() << "\" .reg .pred p" << boolean_counter << "; \\n\"\n";
  indent() << "\" setp.ne.b32 p" << boolean_counter << ", %"
           << counter << ", 0;\\n\"\n";
  boolean_counter++;
}
```
and it converts the operand to `uint32_t`, e.g.:
```cpp
if (register_->dtype() == DataType::Bool) {
  code_ << "(uint32_t)(";
}
```
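To make the operand numbering concrete, here is a minimal standalone C++ model of the two special cases quoted above. `OperandType`, `emitPredicatePreamble`, and `wrapOperand` are hypothetical names for illustration, not nvFuser's API; the real codegen handles more cases and also emits string-literal quoting, which this sketch omits. The point it shows: predicate registers are numbered by their own counter (`boolean_counter`), while each `setp` reads from the operand's position (`%counter`), so any extra operand typed as Bool shifts which operand `p0` is derived from.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical operand-type tag; nvFuser's real DataType is much richer.
enum class OperandType { Bool, UInt32, Int64, Pointer };

// Models the first special case: every Bool operand gets its own predicate
// register pN (N = running count of Bool operands), set from the operand's
// position %i in the asm operand list.
std::string emitPredicatePreamble(const std::vector<OperandType>& operands) {
  std::ostringstream out;
  int boolean_counter = 0;
  for (std::size_t counter = 0; counter < operands.size(); ++counter) {
    if (operands[counter] == OperandType::Bool) {
      out << ".reg .pred p" << boolean_counter << "; ";
      out << "setp.ne.b32 p" << boolean_counter << ", %" << counter << ", 0;\n";
      ++boolean_counter;
    }
  }
  return out.str();
}

// Models the second special case: a Bool operand expression is wrapped in a
// (uint32_t)(...) cast; other operands are passed through unchanged.
std::string wrapOperand(OperandType type, const std::string& expr) {
  return type == OperandType::Bool ? "(uint32_t)(" + expr + ")" : expr;
}
```

With the operand types this thread describes before the fix (two pointer operands misreported as Bool, a 64-bit immediate, then the real predicate), the model sets `p0` from `%0`, `p1` from `%1`, and `p2` from `%3`, matching the before-fix preamble, while the `cp.async` template still names `p0`. With the pointer operands typed correctly, only `%3` gets a predicate register, and it becomes `p0`.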
Before fix:
```cuda
asm volatile(
  "{\n"
  " .reg .pred p0; \n"
  " setp.ne.b32 p0, %0, 0;\n"
  " .reg .pred p1; \n"
  " setp.ne.b32 p1, %1, 0;\n"
  " .reg .pred p2; \n"
  " setp.ne.b32 p2, %3, 0;\n"
  " cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
  "}\n"
  :
  :"r"((uint32_t)((uint32_t)((toSmem(T1) + i0)))),
   "l"((uint32_t)(((T0.data + i0) + i1))),
   "n"(4LL),
   "r"((uint32_t)((!b3)))
);
```
After fix:
```cuda
asm volatile(
  "{\n"
  " .reg .pred p0; \n"
  " setp.ne.b32 p0, %3, 0;\n"
  " cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
  "}\n"
  :
  :"r"((uint32_t)((toSmem(T1) + i0))),
   "l"(((T0.data + i0) + i1)),
   "n"(4LL),
   "r"((uint32_t)((!b3)))
);
```
Oh, I see. I thought it was a dispatch on all dtypes. Didn't realize that it is only for bool.
Issue: https://github.com/NVIDIA/Fuser/issues/3273
The original failure happened in test `DistributedTransformerTest.Backward/__bfloat` when shared memory persistence is used with async copy (after https://github.com/NVIDIA/Fuser/pull/3217, not merged yet).
Reason: a pointer to bool was treated as a bool when handling `kir::Asm` in codegen.
Fix: if the value is a pointer, return the pointer type, not the type it points to.
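The fix can be pictured with an equally simplified dtype model. `DType`, `operandTypeBeforeFix`, `operandTypeAfterFix`, and `getsPredicateRegister` are hypothetical names for illustration, not nvFuser code. Before the fix, resolving a pointer operand's type returned the pointee type, so a `bool*` operand looked like a `bool` and picked up a predicate register and a `(uint32_t)` cast. After the fix, the pointer type is returned unchanged, so only genuine bool operands are special-cased.

```cpp
#include <cassert>
#include <memory>

// Hypothetical, simplified dtype model; nvFuser's real DataType differs.
struct DType {
  enum class Kind { Bool, Int64, Float, Pointer };
  Kind kind;
  std::shared_ptr<const DType> pointee;  // non-null only when kind == Pointer
};

DType scalar(DType::Kind k) { return DType{k, nullptr}; }

DType pointerTo(const DType& pointee) {
  return DType{DType::Kind::Pointer, std::make_shared<DType>(pointee)};
}

// Pre-fix behavior as described in the PR: a pointer reports the type it
// points to, so a bool* operand is indistinguishable from a bool operand.
DType operandTypeBeforeFix(const DType& dt) {
  if (dt.kind == DType::Kind::Pointer) {
    assert(dt.pointee != nullptr);
    return *dt.pointee;
  }
  return dt;
}

// Post-fix behavior: a pointer operand keeps its own pointer type.
DType operandTypeAfterFix(const DType& dt) { return dt; }

// The asm codegen only adds a predicate register (and a uint32_t cast)
// for operands whose resolved type is Bool.
bool getsPredicateRegister(const DType& dt) {
  return dt.kind == DType::Kind::Bool;
}
```

In this sketch the fixed resolver is trivially the identity; in the real code the change presumably sits inside a larger type-resolution routine, but the observable effect is the same: `bool*` no longer triggers the bool-only path.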
Results: Added a unit test; the error is fixed.