NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")
Other
271 stars 53 forks source link

thread predicate is missing in async copy #3427

Open liqiangxl opened 4 days ago

liqiangxl commented 4 days ago

The following test fails when async copy (cp.async, i.e. use_async == true) is used.

// Value-parameterized fixture: the bool parameter selects whether the
// global -> shared memory copies are lowered as cp.async (true) or left as a
// plain set (false).
using NVFuserTestParaBool = NVFuserFixtureParamTest<bool>;

// Reproducer for nvFuser issue #3427: the thread predicate is missing on
// cp.async.
//
// tv3 is a shared-memory copy of the 1-D input tv1 and is parallelized only on
// TIDx. It is consumed (through the broadcast tv4) inside a loop nest that
// also binds TIDy, so only threads with threadIdx.y == 0 should write tv3.
// With cp.async, the generated code passes that thread predicate as the
// instruction's ignore-src operand instead of guarding the instruction.
// As a result, threads with threadIdx.y != 0 overwrite tv3 with zeros.
// The use_async == false instantiation is the non-async baseline.
TEST_P(NVFuserTestParaBool, FusionCpAsyncRace) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
  Fusion fusion;
  FusionGuard fg(&fusion);
  const int m = 64;
  const int n = 96;

  const bool use_async = GetParam();

  // Switch the LoadStoreOp that defines `tv` to cp.async when the test
  // parameter requests it; otherwise leave the op untouched.
  auto maybe_use_cp_async = [use_async](TensorView* tv) {
    if (!use_async) {
      return;
    }
    auto* ldst = tv->definition()->as<LoadStoreOp>();
    ldst->setOpType(LoadStoreOpType::CpAsync);
    ldst->setCacheOp(CacheOp::Unspecified);
  };

  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = makeContigTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // Copy tv0 ([I0, I1]) to shared memory tv2.
  auto tv2 = set(tv0);
  tv2->setMemoryType(MemoryType::Shared);
  maybe_use_cp_async(tv2);

  // Copy tv1 ([I1]) to shared memory tv3.
  auto tv3 = set(tv1);
  tv3->setMemoryType(MemoryType::Shared);
  maybe_use_cp_async(tv3);

  auto tv4 = broadcast(tv3, {true, false});
  auto tv5 = add(tv2, tv4);
  fusion.addOutput(tv5);

  // [I0, I1] -> [ceilDiv(I0, 2), 2, I1], with the middle axis on TIDy and the
  // innermost axis on TIDx. For the broadcast tv4 the outer axis has extent 1.
  for (auto tv : {tv0, tv2, tv4, tv5}) {
    tv->split(0, 2);
    tv->axis(-1)->parallelize(ParallelType::TIDx);
    tv->axis(-2)->parallelize(ParallelType::TIDy);
  }
  // tv1 and tv3 stay 1-D: [I1] with I1 on TIDx only (no TIDy binding), which
  // is why their write needs a threadIdx.y == 0 predicate.
  for (auto tv : {tv1, tv3}) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }

  inlineMost();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({m, n}, options);
  at::Tensor t1 = at::randn({n}, options);

  KernelExecutor ke;
  ke.compile(&fusion, {t0, t1});
  auto cg_outputs = ke.run({t0, t1});
  testValidate(&fusion, cg_outputs, {t0, t1}, __LINE__, __FILE__);
}

// Run once with cp.async (true) and once with plain copies (false).
INSTANTIATE_TEST_SUITE_P(
    ,
    NVFuserTestParaBool,
    ::testing::Values(true, false));

The fusion IR is:

Inputs:
  T0_g_float[ iS10{( ceilDiv(i0, 2) )}, ithreadIdx.y11{2}, ithreadIdx.x1{i2} ]
  T1_g_float[ ithreadIdx.x2{i3} ]
Outputs:
  T5_g_float[ iS16{( ceilDiv(i0, 2) )}, ithreadIdx.y17{2}, ithreadIdx.x9{i2} ] ca_pos( 3 ) produce_pos( 3 )

%kernel_math {
T2_s_float[ iS12{( ceilDiv(i0, 2) )}, ithreadIdx.y13{2}, ithreadIdx.x4{i2} ] ca_pos( 3 )
   = CpAsync( T0_g_float[ iS10{( ceilDiv(i0, 2) )}, ithreadIdx.y11{2}, ithreadIdx.x1{i2} ] )
T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 )
   = CpAsync( T1_g_float[ ithreadIdx.x2{i3} ] )
T4_l_float[ bS14{1}, bthreadIdx.y15{2}, ithreadIdx.x7{i3} ] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 ) )
T5_g_float[ iS16{( ceilDiv(i0, 2) )}, ithreadIdx.y17{2}, ithreadIdx.x9{i2} ] ca_pos( 3 ) produce_pos( 3 )
   = T2_s_float[ iS12{( ceilDiv(i0, 2) )}, ithreadIdx.y13{2}, ithreadIdx.x4{i2} ] ca_pos( 3 )
   + T4_l_float[ bS14{1}, bthreadIdx.y15{2}, ithreadIdx.x7{i3} ] ca_pos( 3 ) produce_pos( 3 );
} // %kernel_math 

The generated kernel code for the expression T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 ) = CpAsync( T1_g_float[ ithreadIdx.x2{i3} ] ) is:

// b9: thread predicate for T3 -- true only for threads with threadIdx.y == 0.
b9 = ((nvfuser_index_t)threadIdx.y) == 0LL;
    // Operands: %0 = i3 (shared-memory destination), %1 = ptr1 (global source),
    // %2 = 4-byte copy size, %3 = !b9. setp sets p0 = (%3 != 0), and p0 is
    // passed as cp.async's ignore-src predicate.
    // BUG (#3427): this cp.async is not guarded by b9. Threads with
    // threadIdx.y != 0 still execute it with p0 set, so they write zeros to
    // the destination instead of skipping the copy.
    asm volatile(
      "{\n"
      "  .reg .pred p0; \n"
      "  setp.ne.b32 p0, %3, 0;\n"
      "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
      "}\n"
      :
      :"r"((uint32_t)(i3)),
       "l"(ptr1),
       "n"(4LL),
       "r"((uint32_t)((!b9)))
    );
    float T4[1LL];
    // Wait for all outstanding cp.async copies of this thread, then barrier
    // so the shared-memory data is visible to the whole block.
    asm volatile("cp.async.wait_all;\n");
    __syncthreads();

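This is wrong because threads with threadIdx.y != 0 still execute the cp.async. Their thread predicate is only passed as the ignore-src operand (!b9), so they write 0 to shared memory and overwrite the correct value loaded by the threadIdx.y == 0 threads. The thread predicate should instead guard the instruction: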

    // Proposed fix: guard the cp.async with the thread predicate b9, so only
    // threads with threadIdx.y == 0 issue the copy.
    if(b9){
      // NOTE(review): inside this guard b9 is true, so operand %3 (!b9) is
      // always 0 and the ignore-src predicate p0 is never set; it is
      // redundant here.
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %3, 0;\n"
        "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
        "}\n"
        :
        :"r"((uint32_t)(i3)),
         "l"(ptr1),
         "n"(4LL),
         "r"((uint32_t)((!b9)))
      );
    }