NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

avoid treating pointer to bool as bool when handling kir::asm in codegen #3274

Closed liqiangxl closed 1 month ago

liqiangxl commented 1 month ago

Issue: https://github.com/NVIDIA/Fuser/issues/3273. The original failure occurred in the test DistributedTransformerTest.Backward/__bfloat when shared memory persistent is used with async copy (after https://github.com/NVIDIA/Fuser/pull/3217, not merged yet). The cause is that a pointer to bool was treated as bool when handling kir::Asm in codegen.

Fix: If the dtype is a pointer, return the pointer type, not the type it points to.

Results: Added a unit test; the error is fixed.
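
As a rough illustration of the fix described above, the sketch below models a type as a scalar kind plus a pointer depth. `SketchType`, `operandTypeBuggy`, `operandTypeFixed`, and `needsPredicate` are hypothetical names invented for this example; they are not nvFuser APIs, and the real change may be structured differently.

```cpp
#include <cassert>

// Hypothetical, simplified stand-in for a data type: a scalar kind plus
// a pointer depth (0 = scalar, 1 = pointer to scalar, ...).
enum class Prim { Bool, Int, Float };

struct SketchType {
  Prim prim;
  int ptr_depth;
  bool operator==(const SketchType& other) const {
    return prim == other.prim && ptr_depth == other.ptr_depth;
  }
};

const SketchType kBool{Prim::Bool, 0};

// Models the pre-fix behavior: the pointer level is dropped, so a
// pointer to bool is reported as plain bool.
SketchType operandTypeBuggy(SketchType t) {
  return SketchType{t.prim, 0};
}

// Models the post-fix behavior: a pointer stays a pointer.
SketchType operandTypeFixed(SketchType t) {
  assert(t.ptr_depth >= 0);  // sanity check on the sketch's own invariant
  return t;
}

// The bool-only special case: only a scalar bool needs a predicate.
bool needsPredicate(SketchType resolved) {
  return resolved == kBool;
}
```

With this model, a pointer to bool passes `needsPredicate` only through `operandTypeBuggy`, while a scalar bool passes through either path.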

liqiangxl commented 1 month ago

!build

zasdfgbnm commented 1 month ago

Could you elaborate more on the problem? Why is only a pointer to bool a problem, and not pointers to other types?

liqiangxl commented 1 month ago

> Could you elaborate more on the problem? Why is only a pointer to bool a problem, and not pointers to other types?

Because handle(const kir::Asm* asm_) in codegen has special treatment for the bool type: it declares a predicate register for each bool operand, e.g.

          if (val->dtype() == DataType::Bool) {
            indent() << "\"  .reg .pred p" << boolean_counter << "; \\n\"\n";
            indent() << "\"  setp.ne.b32 p" << boolean_counter << ", %"
                     << counter << ", 0;\\n\"\n";
            boolean_counter++;
          }

and casts the operand to uint32_t, e.g.

          if (register_->dtype() == DataType::Bool) {
            code_ << "(uint32_t)(";
          }

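Before fix (the pointer operands %0 and %1 are treated as bool, so each gets a predicate and a (uint32_t) cast, and cp.async is guarded by p0, which is derived from %0 instead of the real predicate operand %3):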

  asm volatile(
    "{\n"
    "  .reg .pred p0; \n"
    "  setp.ne.b32 p0, %0, 0;\n"
    "  .reg .pred p1; \n"
    "  setp.ne.b32 p1, %1, 0;\n"
    "  .reg .pred p2; \n"
    "  setp.ne.b32 p2, %3, 0;\n"
    "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
    "}\n"
    :
    :"r"((uint32_t)((uint32_t)((toSmem(T1) + i0)))),
     "l"((uint32_t)(((T0.data + i0) + i1))),
     "n"(4LL),
     "r"((uint32_t)((!b3)))
  );

After fix (only the real bool operand %3 gets a predicate):

  asm volatile(
    "{\n"
    "  .reg .pred p0; \n"
    "  setp.ne.b32 p0, %3, 0;\n"
    "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
    "}\n"
    :
    :"r"((uint32_t)((toSmem(T1) + i0))),
     "l"(((T0.data + i0) + i1)),
     "n"(4LL),
     "r"((uint32_t)((!b3)))
  );
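
The difference between the two listings above can be reproduced with a small model of the predicate loop. This is a sketch only: `emitPredicateSetup` and its `is_bool` parameter are hypothetical names, and the function mirrors just the bool-only branch quoted earlier, not the full `handle(const kir::Asm*)` implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical model of the predicate setup: is_bool[i] says whether the
// codegen sees inline-asm operand %i as bool. Each such operand gets its
// own predicate register, numbered independently of the operand index.
std::string emitPredicateSetup(const std::vector<bool>& is_bool) {
  std::ostringstream out;
  int boolean_counter = 0;
  for (std::size_t counter = 0; counter < is_bool.size(); ++counter) {
    if (is_bool[counter]) {
      out << "  .reg .pred p" << boolean_counter << "; \n";
      out << "  setp.ne.b32 p" << boolean_counter << ", %" << counter
          << ", 0;\n";
      ++boolean_counter;
    }
  }
  // At most one predicate per operand.
  assert(boolean_counter <= static_cast<int>(is_bool.size()));
  return out.str();
}
```

Calling it with `{true, true, false, true}` (pointers to bool misread as bool) yields the three predicates p0, p1, and p2 of the pre-fix listing, while `{false, false, false, true}` yields the single predicate p0 tied to %3.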

zasdfgbnm commented 1 month ago

> Because handle(const kir::Asm* asm_) in codegen has special treatment for bool type

Oh, I see. I thought it was a dispatch on all dtypes. Didn't realize that it is only for bool.