Open liqiangxl opened 4 days ago
The following test fails when async copy is used.
// Value-parameterized fixture: the bool selects whether the global->shared
// copies are lowered to cp.async (true) or kept as regular sets (false).
using NVFuserTestParaBool = NVFuserFixtureParamTest<bool>;

// Repro for a shared-memory race with cp.async.
// tv3 is a shared-memory copy of the 1-D input tv1 and is parallelized only on
// TIDx, while the kernel also uses TIDy (from tv0/tv2/tv4/tv5). Every
// threadIdx.y therefore targets the same tv3 element; the test checks that the
// non-writing threads do not clobber the loaded value.
TEST_P(NVFuserTestParaBool, FusionCpAsyncRace) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

  Fusion fusion;
  FusionGuard fg(&fusion);

  const int m = 64;
  const int n = 96;
  // Use the parameterized value for use_async
  const bool use_async = GetParam();

  // Re-lower the LoadStoreOp that defines `tv` as a cp.async copy.
  auto make_cp_async = [](TensorView* tv) {
    auto* ldst = tv->definition()->as<LoadStoreOp>();
    ldst->setOpType(LoadStoreOpType::CpAsync);
    ldst->setCacheOp(CacheOp::Unspecified);
  };

  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = makeContigTensor(1);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  // copy tv0 to shared memory tv2
  auto tv2 = set(tv0);
  tv2->setMemoryType(MemoryType::Shared);
  if (use_async) {
    make_cp_async(tv2);
  }

  // copy tv1 to shared memory tv3
  auto tv3 = set(tv1);
  tv3->setMemoryType(MemoryType::Shared);
  if (use_async) {
    make_cp_async(tv3);
  }

  auto tv4 = broadcast(tv3, {true, false});
  auto tv5 = add(tv2, tv4);
  fusion.addOutput(tv5);

  // [I0, I1] --> [ceilDiv(I0, 2), 2, I1]  (tv4's outer axis is a broadcast).
  // The inner split factor goes on TIDy and I1 on TIDx.
  for (auto tv : {tv0, tv2, tv4, tv5}) {
    tv->split(0, 2);
    tv->axis(-1)->parallelize(ParallelType::TIDx);
    tv->axis(-2)->parallelize(ParallelType::TIDy);
  }

  // [I1]: parallelized on TIDx only, so all threadIdx.y share each element.
  for (auto tv : {tv1, tv3}) {
    tv->axis(-1)->parallelize(ParallelType::TIDx);
  }

  inlineMost();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({m, n}, options);
  at::Tensor t1 = at::randn({n}, options);

  KernelExecutor ke;
  ke.compile(&fusion, {t0, t1});
  auto cg_outputs = ke.run({t0, t1});
  testValidate(&fusion, cg_outputs, {t0, t1}, __LINE__, __FILE__);
}

INSTANTIATE_TEST_SUITE_P(
    ,
    NVFuserTestParaBool,
    ::testing::Values(true, false));
The fusion_ir is:
Inputs: T0_g_float[ iS10{( ceilDiv(i0, 2) )}, ithreadIdx.y11{2}, ithreadIdx.x1{i2} ] T1_g_float[ ithreadIdx.x2{i3} ] Outputs: T5_g_float[ iS16{( ceilDiv(i0, 2) )}, ithreadIdx.y17{2}, ithreadIdx.x9{i2} ] ca_pos( 3 ) produce_pos( 3 ) %kernel_math { T2_s_float[ iS12{( ceilDiv(i0, 2) )}, ithreadIdx.y13{2}, ithreadIdx.x4{i2} ] ca_pos( 3 ) = CpAsync( T0_g_float[ iS10{( ceilDiv(i0, 2) )}, ithreadIdx.y11{2}, ithreadIdx.x1{i2} ] ) T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 ) = CpAsync( T1_g_float[ ithreadIdx.x2{i3} ] ) T4_l_float[ bS14{1}, bthreadIdx.y15{2}, ithreadIdx.x7{i3} ] ca_pos( 3 ) produce_pos( 3 ) = broadcast( T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 ) ) T5_g_float[ iS16{( ceilDiv(i0, 2) )}, ithreadIdx.y17{2}, ithreadIdx.x9{i2} ] ca_pos( 3 ) produce_pos( 3 ) = T2_s_float[ iS12{( ceilDiv(i0, 2) )}, ithreadIdx.y13{2}, ithreadIdx.x4{i2} ] ca_pos( 3 ) + T4_l_float[ bS14{1}, bthreadIdx.y15{2}, ithreadIdx.x7{i3} ] ca_pos( 3 ) produce_pos( 3 ); } // %kernel_math
The generated kernel for expr T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 ) = CpAsync( T1_g_float[ ithreadIdx.x2{i3} ] ) is:
T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 ) = CpAsync( T1_g_float[ ithreadIdx.x2{i3} ] )
b9 = ((nvfuser_index_t)threadIdx.y) == 0LL; asm volatile( "{\n" " .reg .pred p0; \n" " setp.ne.b32 p0, %3, 0;\n" " cp.async.ca.shared.global [%0], [%1], %2, p0;\n" "}\n" : :"r"((uint32_t)(i3)), "l"(ptr1), "n"(4LL), "r"((uint32_t)((!b9))) ); float T4[1LL]; asm volatile("cp.async.wait_all;\n"); __syncthreads();
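This is wrong because threads with threadIdx.y != 0 still issue the cp.async. The predicate (!b9) is passed as the instruction's ignore-src operand, so those threads copy 0 into the same shared-memory element (T3 is parallelized only on threadIdx.x). That zero-fill overwrites the value loaded by the threadIdx.y == 0 threads. The correct code should instead guard the whole copy with the predicate: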
threadIdx.y != 0
0
if(b9){ asm volatile( "{\n" " .reg .pred p0; \n" " setp.ne.b32 p0, %3, 0;\n" " cp.async.ca.shared.global [%0], [%1], %2, p0;\n" "}\n" : :"r"((uint32_t)(i3)), "l"(ptr1), "n"(4LL), "r"((uint32_t)((!b9))) ); }
The following test fails when async copy is used.
The fusion_ir is:
The generated kernel for expr
T3_s_float[ ithreadIdx.x5{i3} ] ca_pos( 1 ) = CpAsync( T1_g_float[ ithreadIdx.x2{i3} ] )
is: This is wrong, because threads with
threadIdx.y != 0
will write 0
to shared memory which pollutes the correct value. The correct one should be: