With the change in https://github.com/csarofeen/pytorch/pull/2484, `cp.async` automatically fills zero for out-of-bound elements, so there is no longer any need to zero-initialize the shared memory buffers before copying into them.

This PR also contains some miscellaneous cleanup.
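For background, `cp.async` takes an optional `src-size` operand: when it is smaller than the copy size, the remaining destination bytes are zero-filled. Below is a minimal sketch of how an out-of-bounds copy can be turned into a zero-fill this way; `cp_async_cg_zfill` is a hypothetical helper for illustration only, not the actual nvFuser runtime function (the generated kernel below calls `Ampere::cpAsyncCg`, which wraps the real implementation):

```cuda
// Hypothetical illustration of cp.async zero-fill (requires sm_80+).
// cp.async.cg always transfers a 16-byte chunk into shared memory; the
// optional src-size operand limits how many bytes are read from global
// memory, and any remaining bytes are written as zero. Passing 0 when the
// source is out of bounds therefore zero-fills the whole chunk without
// reading global memory at all.
__device__ inline void cp_async_cg_zfill(unsigned smem_addr,
                                         const void* gmem_ptr,
                                         bool in_bounds) {
  int src_size = in_bounds ? 16 : 0;  // 0 => full zero-fill of the chunk
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n"
               :
               : "r"(smem_addr), "l"(gmem_ptr), "r"(src_size));
}
```

Because the out-of-bounds case is handled inside the copy itself, the predicate can be folded into the `cp.async` call instead of emitting separate zero-initialization loops for the out-of-bound region.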
Example kernel:
```cuda
__global__ void kernel1(Tensor<__half, 2> T0, Tensor<__half, 2> T1, Tensor<float, 2> T4) {
  alignas(16) extern __shared__ char array[];
  unsigned smem_offset = 0;
  int i594;
  i594 = (T0.size[1] * ((nvfuser_index_t)threadIdx.z)) * 16;
  int i597;
  i597 = (T0.size[1] * 8) * ((nvfuser_index_t)threadIdx.y);
  int i528;
  i528 = ((nvfuser_index_t)threadIdx.x) / 4;
  int i599;
  i599 = T0.size[1] * i528;
  int i548;
  i548 = ((nvfuser_index_t)threadIdx.x) % 4;
  int i601;
  i601 = i548 * 8;
  int i600;
  i600 = ((((T0.size[1] * ((nvfuser_index_t)blockIdx.y)) * 128) + i594) + i597) + i599;
  int i602;
  i602 = i600 + i601;
  int i605;
  i605 = T0.size[1] * 32;
  int i804;
  i804 = ((nvfuser_index_t)threadIdx.z) * 1024;
  int i806;
  i806 = 512 * ((nvfuser_index_t)threadIdx.y);
  int i808;
  i808 = 16 * ((nvfuser_index_t)threadIdx.x);
  int i1187;
  i1187 = ((((T0.size[1] * ((nvfuser_index_t)blockIdx.x)) * 128) + i594) + i597) + i599;
  int i1188;
  i1188 = i1187 + i601;
  int i1785;
  i1785 = (i600 + 64) + i601;
  int i2412;
  i2412 = (i1187 + 64) + i601;
  int i2819;
  i2819 = ((nvfuser_index_t)threadIdx.z) * 4096;
  int i2690;
  i2690 = ((nvfuser_index_t)threadIdx.x) / 8;
  int i2721;
  i2721 = i2690 % 2;
  int i2821;
  i2821 = i2721 * 512;
  int i2823;
  i2823 = 64 * (((nvfuser_index_t)threadIdx.x) % 8);
  int i2728;
  i2728 = i2690 / 2;
  int i2825;
  i2825 = i2728 * 16;
  int i3004;
  i3004 = ((nvfuser_index_t)threadIdx.y) * 4096;
  int i3017;
  i3017 = i2728 * 512;
  int i3020;
  i3020 = i2721 * 16;
  int i3988;
  i3988 = ((nvfuser_index_t)blockIdx.y) * 128;
  int i3990;
  i3990 = ((nvfuser_index_t)threadIdx.y) * 64;
  int i4009;
  i4009 = i548 * 2;
  int i4010;
  i4010 = ((((((T1.size[0] * ((nvfuser_index_t)blockIdx.x)) * 128) + ((T1.size[0] * ((nvfuser_index_t)threadIdx.z)) * 64)) + (T1.size[0] * i528)) + i3988) + i3990) + i4009;
  int i3994;
  i3994 = T1.size[0] * 16;
  int i4001;
  i4001 = T1.size[0] * 8;
  int i4656;
  i4656 = (8 * ((nvfuser_index_t)threadIdx.x)) + 7;
  int i4658;
  i4658 = (i4656 % 32) - T0.size[1];
  int i4679;
  i4679 = ((((nvfuser_index_t)threadIdx.z) * 16) + (8 * ((nvfuser_index_t)threadIdx.y))) + (i4656 / 32);
  int i4681;
  i4681 = (i4679 + i3988) - T1.size[0];
  int i5335;
  i5335 = ((nvfuser_index_t)blockIdx.x) * 128;
  int i5343;
  i5343 = (i4679 + i5335) - T0.size[0];
  int i6915;
  i6915 = (((((nvfuser_index_t)threadIdx.z) * 64) + i528) + i5335) - T0.size[0];
  int i6919;
  i6919 = ((i3990 + i4009) + i3988) - T1.size[0];
  smem_offset = alignBufferSize(smem_offset, 16);
  __half* T7 = reinterpret_cast<__half*>(array + smem_offset);
  smem_offset += (12288 * sizeof(__half));
  smem_offset = alignBufferSize(smem_offset, 16);
  __half* T6 = reinterpret_cast<__half*>(array + smem_offset);
  smem_offset += (12288 * sizeof(__half));
  unsigned i649;
  i649 = toSmem(T7);
  unsigned i809;
  i809 = ((i649 + i804) + i806) + i808;
  unsigned i1233;
  i1233 = toSmem(T6);
  unsigned i1400;
  i1400 = ((i1233 + i804) + i806) + i808;
  unsigned i2833;
  i2833 = ((i1233 + i2819) + i2821) + i2823;
  unsigned i2834;
  i2834 = i2833 + i2825;
  unsigned i3019;
  i3019 = ((i649 + i3004) + i3017) + i2823;
  unsigned i3021;
  i3021 = i3019 + i3020;
  unsigned i3230;
  i3230 = (i2833 + 32) + i2825;
  unsigned i3505;
  i3505 = (i3019 + 32) + i3020;
  float T5[128];
  #pragma unroll
  for(nvfuser_index_t i142 = 0; i142 < 4; ++i142) {
    int i234;
    i234 = 32 * i142;
    #pragma unroll
    for(nvfuser_index_t i143 = 0; i143 < 4; ++i143) {
      Ampere::initM16N16K16TN<16>(reinterpret_cast<Array<float,8,8>*>(&T5[(i234 + (4 * i143))]));
    }
  }
  #pragma unroll
  for(nvfuser_index_t i136 = 0; i136 < 2; ++i136) {
    int i603;
    i603 = 32 * i136;
    int i604;
    i604 = i602 + i603;
    int i810;
    i810 = 8192 * i136;
    unsigned i811;
    i811 = i809 + i810;
    int i1189;
    i1189 = i1188 + i603;
    unsigned i1401;
    i1401 = i1400 + i810;
    bool b4677;
    b4677 = i4658 < (-i603);
    #pragma unroll
    for(nvfuser_index_t i130 = 0; i130 < 4; ++i130) {
      Ampere::cpAsyncCg<__half, 8>((i811 + (2048 * i130)), reinterpret_cast<Array<__half,8,8>*>(&T1[(i604 + (i605 * i130))]), (b4677 && (i4681 < (-(32 * i130)))));
    }
    #pragma unroll
    for(nvfuser_index_t i119 = 0; i119 < 4; ++i119) {
      Ampere::cpAsyncCg<__half, 8>((i1401 + (2048 * i119)), reinterpret_cast<Array<__half,8,8>*>(&T0[(i1189 + (i605 * i119))]), (b4677 && (i5343 < (-(32 * i119)))));
    }
    Ampere::cpAsyncCommit();
  }
  Ampere::cpAsyncPartialBarrier<1>();
  __syncthreads();
  #pragma unroll 1
  for(nvfuser_index_t i137 = 0; i137 < (ceilDiv(T0.size[1], 32)); ++i137) {
    int i1779;
    i1779 = 32 * i137;
    int i1786;
    i1786 = i1785 + i1779;
    int i2024;
    i2024 = 8192 * ((2 + i137) % 3);
    unsigned i2028;
    i2028 = i809 + i2024;
    int i2413;
    i2413 = i2412 + i1779;
    unsigned i2655;
    i2655 = i1400 + i2024;
    int i2827;
    i2827 = 8192 * (i137 % 3);
    unsigned i2835;
    i2835 = i2834 + i2827;
    unsigned i3022;
    i3022 = i3021 + i2827;
    unsigned i3231;
    i3231 = i3230 + i2827;
    unsigned i3506;
    i3506 = i3505 + i2827;
    bool b6037;
    b6037 = (i4658 + i1779) < -64;
    #pragma unroll
    for(nvfuser_index_t i130 = 0; i130 < 4; ++i130) {
      Ampere::cpAsyncCg<__half, 8>((i2028 + (2048 * i130)), reinterpret_cast<Array<__half,8,8>*>(&T1[(i1786 + (i605 * i130))]), (b6037 && (i4681 < (-(32 * i130)))));
    }
    #pragma unroll
    for(nvfuser_index_t i119 = 0; i119 < 4; ++i119) {
      Ampere::cpAsyncCg<__half, 8>((i2655 + (2048 * i119)), reinterpret_cast<Array<__half,8,8>*>(&T0[(i2413 + (i605 * i119))]), (b6037 && (i5343 < (-(32 * i119)))));
    }
    Ampere::cpAsyncCommit();
    Array<__half, 64, 8> T8;
    Array<__half, 64, 8> T9;
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 4; ++i121) {
      Turing::ldMatrix(*reinterpret_cast<Array<__half,8,8>*>(&T8[(8 * i121)]), (i2835 + (1024 * i121)));
    }
    #pragma unroll
    for(nvfuser_index_t i132 = 0; i132 < 4; ++i132) {
      Turing::ldMatrix(*reinterpret_cast<Array<__half,8,8>*>(&T9[(8 * i132)]), (i3022 + (1024 * i132)));
    }
    #pragma unroll
    for(nvfuser_index_t i140 = 0; i140 < 1; ++i140) {
      int i3225;
      i3225 = 32 * i140;
      unsigned i3232;
      i3232 = i3231 + i3225;
      int i3259;
      i3259 = 32 * ((1 + i140) % 2);
      int i3298;
      i3298 = 32 * (i140 % 2);
      unsigned i3507;
      i3507 = i3506 + i3225;
      #pragma unroll
      for(nvfuser_index_t i121 = 0; i121 < 4; ++i121) {
        Turing::ldMatrix(*reinterpret_cast<Array<__half,8,8>*>(&T8[(i3259 + (8 * i121))]), (i3232 + (1024 * i121)));
      }
      __half T2[32];
      #pragma unroll
      for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
        int i3269;
        i3269 = 8 * i123;
        #pragma unroll
        for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
          T2[(i3269 + i125)] = 0.00000000000000000e+00;
        }
      }
      #pragma unroll
      for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
        int i3278;
        i3278 = 8 * i123;
        int i3302;
        i3302 = i3298 + i3278;
        #pragma unroll
        for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
          T2[(i3278 + i125)] = T8[(i3302 + i125)];
        }
      }
      #pragma unroll
      for(nvfuser_index_t i132 = 0; i132 < 4; ++i132) {
        Turing::ldMatrix(*reinterpret_cast<Array<__half,8,8>*>(&T9[(i3259 + (8 * i132))]), (i3507 + (1024 * i132)));
      }
      __half T3[32];
      #pragma unroll
      for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
        int i3545;
        i3545 = 8 * i134;
        #pragma unroll
        for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
          T3[(i3545 + i135)] = 0.00000000000000000e+00;
        }
      }
      #pragma unroll
      for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
        int i3554;
        i3554 = 8 * i134;
        int i3578;
        i3578 = i3298 + i3554;
        #pragma unroll
        for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
          T3[(i3554 + i135)] = T9[(i3578 + i135)];
        }
      }
      #pragma unroll
      for(nvfuser_index_t i142 = 0; i142 < 4; ++i142) {
        int i3645;
        i3645 = 32 * i142;
        #pragma unroll
        for(nvfuser_index_t i143 = 0; i143 < 4; ++i143) {
          Ampere::M16N16K16TN<16>(
            reinterpret_cast<Array<float,8,8>*>(&T5[(i3645 + (4 * i143))]),
            &(reinterpret_cast<Array<__half,8,8>*>(&T2)[i142]),
            &(reinterpret_cast<Array<__half,8,8>*>(&T3)[i143]));
        }
      }
    }
    __half T2[32];
    #pragma unroll
    for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
      int i3654;
      i3654 = 8 * i123;
      #pragma unroll
      for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
        T2[(i3654 + i125)] = 0.00000000000000000e+00;
      }
    }
    #pragma unroll
    for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
      int i3662;
      i3662 = 8 * i123;
      int i3681;
      i3681 = 32 + i3662;
      #pragma unroll
      for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
        T2[(i3662 + i125)] = T8[(i3681 + i125)];
      }
    }
    __half T3[32];
    #pragma unroll
    for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
      int i3689;
      i3689 = 8 * i134;
      #pragma unroll
      for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
        T3[(i3689 + i135)] = 0.00000000000000000e+00;
      }
    }
    #pragma unroll
    for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
      int i3697;
      i3697 = 8 * i134;
      int i3716;
      i3716 = 32 + i3697;
      #pragma unroll
      for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
        T3[(i3697 + i135)] = T9[(i3716 + i135)];
      }
    }
    #pragma unroll
    for(nvfuser_index_t i142 = 0; i142 < 4; ++i142) {
      int i3780;
      i3780 = 32 * i142;
      #pragma unroll
      for(nvfuser_index_t i143 = 0; i143 < 4; ++i143) {
        Ampere::M16N16K16TN<16>(
          reinterpret_cast<Array<float,8,8>*>(&T5[(i3780 + (4 * i143))]),
          &(reinterpret_cast<Array<__half,8,8>*>(&T2)[i142]),
          &(reinterpret_cast<Array<__half,8,8>*>(&T3)[i143]));
      }
    }
    Ampere::cpAsyncPartialBarrier<1>();
    __syncthreads();
  }
  #pragma unroll
  for(nvfuser_index_t i148 = 0; i148 < 4; ++i148) {
    int i3846;
    i3846 = 32 * i148;
    int i4011;
    i4011 = i4010 + (i3994 * i148);
    int i6894;
    i6894 = -(16 * i148);
    #pragma unroll
    for(nvfuser_index_t i149 = 0; i149 < 4; ++i149) {
      int i3848;
      i3848 = i3846 + (4 * i149);
      int i3997;
      i3997 = 16 * i149;
      int i4012;
      i4012 = i4011 + i3997;
      int i6920;
      i6920 = -i3997;
      #pragma unroll
      for(nvfuser_index_t i150 = 0; i150 < 2; ++i150) {
        int i3850;
        i3850 = i3848 + (2 * i150);
        int i3999;
        i3999 = 8 * i150;
        int i4013;
        i4013 = i4012 + i3999;
        int i6921;
        i6921 = i6920 - i3999;
        #pragma unroll
        for(nvfuser_index_t i151 = 0; i151 < 2; ++i151) {
          int i3852;
          i3852 = i3850 + (16 * i151);
          int i4014;
          i4014 = i4013 + (i4001 * i151);
          bool b6916;
          b6916 = i6915 < (i6894 - (8 * i151));
          #pragma unroll
          for(nvfuser_index_t i152 = 0; i152 < 2; ++i152) {
            if ((b6916 && (i6919 < (i6921 - i152)))) {
              T4[(i4014 + i152)] = T5[(i3852 + i152)];
            }
          }
        }
      }
    }
  }
}
```
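Note the boolean predicate passed as the last argument to each `Ampere::cpAsyncCg` call (e.g. `b4677 && (i4681 < (-(32 * i130)))`): when it is false, the 16-byte destination chunk is zero-filled rather than loaded from global memory, so no separate zero-initialization of the shared memory buffers `T6`/`T7` is needed anymore.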
cc @mmigdal-nv