csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org
Other
26 stars 7 forks source link

Don't initialize if using cpAsync #2485

Closed zasdfgbnm closed 1 year ago

zasdfgbnm commented 1 year ago

With the change in https://github.com/csarofeen/pytorch/pull/2484, cp.async automatically zero-fills out-of-bounds elements, so explicit zero-initialization is no longer needed.

This PR also contains some misc cleanup.

Example kernel:

// Generated example kernel.
//
// Inputs T0 and T1 are 2-D __half tensors; output T4 is a 2-D float tensor.
// The kernel
//   1. carves two __half buffers (T7, T6) out of dynamic shared memory,
//   2. fills them with Ampere::cpAsyncCg copies from T1 and T0,
//   3. reads them back with Turing::ldMatrix into T8/T9, stages 32-element
//      slices in T2/T3 and feeds those to Ampere::M16N16K16TN, which
//      accumulates into the local float array T5,
//   4. stores T5 to T4 under per-element bounds predicates.
// NOTE(review): Ampere::*, Turing::*, toSmem and alignBufferSize are defined
// outside this listing; their exact semantics are assumed from their names
// and call sites -- confirm against the runtime headers.
__global__ void kernel1(Tensor<__half, 2> T0, Tensor<__half, 2> T1, Tensor<float, 2> T4) {
  // Dynamic shared-memory pool; smem_offset is the running allocation cursor.
  alignas(16) extern __shared__ char array[];
  unsigned smem_offset = 0;
  // ---- Loop-invariant index terms -----------------------------------------
  // Everything down to i6919 depends only on thread/block indices and tensor
  // sizes, so it is computed once here and reused inside the loops below.
  //
  // i594..i605: terms of the element offsets used to index T1 and T0 as
  // cp.async sources. All of them are scaled by T0.size[1].
  int i594;
  i594 = (T0.size[1] * ((nvfuser_index_t)threadIdx.z)) * 16;
  int i597;
  i597 = (T0.size[1] * 8) * ((nvfuser_index_t)threadIdx.y);
  int i528;
  i528 = ((nvfuser_index_t)threadIdx.x) / 4;
  int i599;
  i599 = T0.size[1] * i528;
  int i548;
  i548 = ((nvfuser_index_t)threadIdx.x) % 4;
  int i601;
  i601 = i548 * 8;
  // i600 uses blockIdx.y; i602 derived from it indexes T1 in the prologue.
  int i600;
  i600 = ((((T0.size[1] * ((nvfuser_index_t)blockIdx.y)) * 128) + i594) + i597) + i599;
  int i602;
  i602 = i600 + i601;
  // Source-offset step between the four copies issued per inner cp.async loop.
  int i605;
  i605 = T0.size[1] * 32;
  // i804/i806/i808: per-thread offsets added to the shared-memory addresses
  // that serve as cp.async destinations (see i809 and i1400).
  int i804;
  i804 = ((nvfuser_index_t)threadIdx.z) * 1024;
  int i806;
  i806 = 512 * ((nvfuser_index_t)threadIdx.y);
  int i808;
  i808 = 16 * ((nvfuser_index_t)threadIdx.x);
  // i1187 mirrors i600 with blockIdx.x; i1188 derived from it indexes T0.
  int i1187;
  i1187 = ((((T0.size[1] * ((nvfuser_index_t)blockIdx.x)) * 128) + i594) + i597) + i599;
  int i1188;
  i1188 = i1187 + i601;
  // Same source offsets shifted by 64 elements; used by the main loop, which
  // issues copies two 32-element steps ahead of the data it consumes.
  int i1785;
  i1785 = (i600 + 64) + i601;
  int i2412;
  i2412 = (i1187 + 64) + i601;
  // i2819..i3020: per-thread terms of the shared-memory addresses passed to
  // Turing::ldMatrix (see i2833/i2834 for T6 and i3019/i3021 for T7).
  int i2819;
  i2819 = ((nvfuser_index_t)threadIdx.z) * 4096;
  int i2690;
  i2690 = ((nvfuser_index_t)threadIdx.x) / 8;
  int i2721;
  i2721 = i2690 % 2;
  int i2821;
  i2821 = i2721 * 512;
  int i2823;
  i2823 = 64 * (((nvfuser_index_t)threadIdx.x) % 8);
  int i2728;
  i2728 = i2690 / 2;
  int i2825;
  i2825 = i2728 * 16;
  int i3004;
  i3004 = ((nvfuser_index_t)threadIdx.y) * 4096;
  int i3017;
  i3017 = i2728 * 512;
  int i3020;
  i3020 = i2721 * 16;
  // i3988..i4001: terms of the T4 output index. Row-like terms are scaled by
  // T1.size[0].
  int i3988;
  i3988 = ((nvfuser_index_t)blockIdx.y) * 128;
  int i3990;
  i3990 = ((nvfuser_index_t)threadIdx.y) * 64;
  int i4009;
  i4009 = i548 * 2;
  int i4010;
  i4010 = ((((((T1.size[0] * ((nvfuser_index_t)blockIdx.x)) * 128) + ((T1.size[0] * ((nvfuser_index_t)threadIdx.z)) * 64)) + (T1.size[0] * i528)) + i3988) + i3990) + i4009;
  int i3994;
  i3994 = T1.size[0] * 16;
  int i4001;
  i4001 = T1.size[0] * 8;
  // i4656..i6919: left-hand sides of the bounds predicates. Each has the form
  // "index - extent", so a check "index + k < extent" is written further down
  // as "value < -k".
  // i4656 is the last of the 8 consecutive elements this thread copies.
  int i4656;
  i4656 = (8 * ((nvfuser_index_t)threadIdx.x)) + 7;
  int i4658;
  i4658 = (i4656 % 32) - T0.size[1];
  int i4679;
  i4679 = ((((nvfuser_index_t)threadIdx.z) * 16) + (8 * ((nvfuser_index_t)threadIdx.y))) + (i4656 / 32);
  int i4681;
  i4681 = (i4679 + i3988) - T1.size[0];
  int i5335;
  i5335 = ((nvfuser_index_t)blockIdx.x) * 128;
  int i5343;
  i5343 = (i4679 + i5335) - T0.size[0];
  // Predicate terms for the final store to T4.
  int i6915;
  i6915 = (((((nvfuser_index_t)threadIdx.z) * 64) + i528) + i5335) - T0.size[0];
  int i6919;
  i6919 = ((i3990 + i4009) + i3988) - T1.size[0];
  // ---- Shared-memory carve-out --------------------------------------------
  // Two buffers of 12288 __half each, both starting on a 16-byte boundary.
  // T7 is the destination of the copies from T1, T6 of the copies from T0.
  smem_offset = alignBufferSize(smem_offset, 16);
  __half* T7 = reinterpret_cast<__half*>(array + smem_offset);
  smem_offset += (12288 * sizeof(__half));
  smem_offset = alignBufferSize(smem_offset, 16);
  __half* T6 = reinterpret_cast<__half*>(array + smem_offset);
  smem_offset += (12288 * sizeof(__half));
  // toSmem() results are used as base addresses for cp.async / ldMatrix.
  // NOTE(review): presumably byte addresses in the shared address space, in
  // which case 12288 halves = 3 * 8192 bytes matches the 8192-stride, mod-3
  // stage selection below -- confirm.
  unsigned i649;
  i649 = toSmem(T7);
  // Per-thread cp.async destination inside T7.
  unsigned i809;
  i809 = ((i649 + i804) + i806) + i808;
  unsigned i1233;
  i1233 = toSmem(T6);
  // Per-thread cp.async destination inside T6.
  unsigned i1400;
  i1400 = ((i1233 + i804) + i806) + i808;
  // ldMatrix source addresses in T6 (i2834) and T7 (i3021) ...
  unsigned i2833;
  i2833 = ((i1233 + i2819) + i2821) + i2823;
  unsigned i2834;
  i2834 = i2833 + i2825;
  unsigned i3019;
  i3019 = ((i649 + i3004) + i3017) + i2823;
  unsigned i3021;
  i3021 = i3019 + i3020;
  // ... and the same addresses advanced by 32, used for the second half of
  // T8/T9.
  unsigned i3230;
  i3230 = (i2833 + 32) + i2825;
  unsigned i3505;
  i3505 = (i3019 + 32) + i3020;
  // ---- Accumulator --------------------------------------------------------
  // T5 holds this thread's 128 float results. It is initialized in 4 x 4
  // sub-tiles through Ampere::initM16N16K16TN.
  float T5[128];
  #pragma unroll
  for(nvfuser_index_t i142 = 0; i142 < 4; ++i142) {
    int i234;
    i234 = 32 * i142;
    #pragma unroll
    for(nvfuser_index_t i143 = 0; i143 < 4; ++i143) {
      Ampere::initM16N16K16TN<16>(reinterpret_cast<Array<float,8,8>*>(&T5[(i234 + (4 * i143))]));
    }
  }
  // ---- Prologue: issue the first two cp.async groups ----------------------
  // Iteration i136 copies the 32-element slice starting at 32 * i136 into the
  // shared-memory stage at offset 8192 * i136.
  #pragma unroll
  for(nvfuser_index_t i136 = 0; i136 < 2; ++i136) {
    int i603;
    i603 = 32 * i136;
    int i604;
    i604 = i602 + i603;
    int i810;
    i810 = 8192 * i136;
    unsigned i811;
    i811 = i809 + i810;
    int i1189;
    i1189 = i1188 + i603;
    unsigned i1401;
    i1401 = i1400 + i810;
    // Equivalent to (i4656 % 32) + i603 < T0.size[1].
    bool b4677;
    b4677 = i4658 < (-i603);
    // T1 -> T7. The last argument is the copy predicate; per the PR text,
    // cp.async zero-fills elements whose predicate is false.
    #pragma unroll
    for(nvfuser_index_t i130 = 0; i130 < 4; ++i130) {
      Ampere::cpAsyncCg<__half, 8>((i811 + (2048 * i130)),reinterpret_cast<Array<__half,8,8>*>(&T1[(i604 + (i605 * i130))]),(b4677 && (i4681 < (-(32 * i130)))));
    }
    // T0 -> T6, with the matching bound on T0.size[0].
    #pragma unroll
    for(nvfuser_index_t i119 = 0; i119 < 4; ++i119) {
      Ampere::cpAsyncCg<__half, 8>((i1401 + (2048 * i119)),reinterpret_cast<Array<__half,8,8>*>(&T0[(i1189 + (i605 * i119))]),(b4677 && (i5343 < (-(32 * i119)))));
    }
    Ampere::cpAsyncCommit();
  }
  // NOTE(review): presumably waits until at most one committed cp.async group
  // is still in flight -- confirm. The __syncthreads() that follows is a
  // block-wide barrier before any thread reads the shared buffers.
  Ampere::cpAsyncPartialBarrier<1>();
  __syncthreads();
  // ---- Main loop ----------------------------------------------------------
  // One iteration per 32-element step along T0.size[1]. Each iteration issues
  // copies into stage (2 + i137) % 3 while reading from stage i137 % 3.
  #pragma unroll 1
  for(nvfuser_index_t i137 = 0; i137 < (ceilDiv(T0.size[1], 32)); ++i137) {
    int i1779;
    i1779 = 32 * i137;
    // Source offset for T1: the +64 baked into i1785 puts it two steps ahead.
    int i1786;
    i1786 = i1785 + i1779;
    // Offset of the stage being filled.
    int i2024;
    i2024 = 8192 * ((2 + i137) % 3);
    unsigned i2028;
    i2028 = i809 + i2024;
    int i2413;
    i2413 = i2412 + i1779;
    unsigned i2655;
    i2655 = i1400 + i2024;
    // Offset of the stage being read.
    int i2827;
    i2827 = 8192 * (i137 % 3);
    unsigned i2835;
    i2835 = i2834 + i2827;
    unsigned i3022;
    i3022 = i3021 + i2827;
    unsigned i3231;
    i3231 = i3230 + i2827;
    unsigned i3506;
    i3506 = i3505 + i2827;
    // Equivalent to (i4656 % 32) + 32 * i137 + 64 < T0.size[1].
    bool b6037;
    b6037 = (i4658 + i1779) < -64;
    // Issue the look-ahead copies: T1 -> T7, then T0 -> T6.
    #pragma unroll
    for(nvfuser_index_t i130 = 0; i130 < 4; ++i130) {
      Ampere::cpAsyncCg<__half, 8>((i2028 + (2048 * i130)),reinterpret_cast<Array<__half,8,8>*>(&T1[(i1786 + (i605 * i130))]),(b6037 && (i4681 < (-(32 * i130)))));
    }
    #pragma unroll
    for(nvfuser_index_t i119 = 0; i119 < 4; ++i119) {
      Ampere::cpAsyncCg<__half, 8>((i2655 + (2048 * i119)),reinterpret_cast<Array<__half,8,8>*>(&T0[(i2413 + (i605 * i119))]),(b6037 && (i5343 < (-(32 * i119)))));
    }
    Ampere::cpAsyncCommit();
    // T8 is loaded from T6-based addresses and T9 from T7-based addresses.
    // Each holds 64 halves: elements [0, 32) are filled here and [32, 64)
    // inside the i140 loop below.
    Array<__half, 64, 8> T8;
    Array<__half, 64, 8> T9;
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 4; ++i121) {
      Turing::ldMatrix (*reinterpret_cast<Array<__half,8,8>*>(&T8[(8 * i121)]),(i2835 + (1024 * i121)));
    }
    #pragma unroll
    for(nvfuser_index_t i132 = 0; i132 < 4; ++i132) {
      Turing::ldMatrix (*reinterpret_cast<Array<__half,8,8>*>(&T9[(8 * i132)]),(i3022 + (1024 * i132)));
    }
    // Single-trip loop (i140 == 0): loads the second half of T8/T9 and runs
    // the multiply on the first half. With i140 == 0, i3259 == 32 and
    // i3298 == 0.
    #pragma unroll
    for(nvfuser_index_t i140 = 0; i140 < 1; ++i140) {
      int i3225;
      i3225 = 32 * i140;
      unsigned i3232;
      i3232 = i3231 + i3225;
      // Half of T8/T9 to load into.
      int i3259;
      i3259 = 32 * ((1 + i140) % 2);
      // Half of T8/T9 to consume.
      int i3298;
      i3298 = 32 * (i140 % 2);
      unsigned i3507;
      i3507 = i3506 + i3225;
      #pragma unroll
      for(nvfuser_index_t i121 = 0; i121 < 4; ++i121) {
        Turing::ldMatrix (*reinterpret_cast<Array<__half,8,8>*>(&T8[(i3259 + (8 * i121))]),(i3232 + (1024 * i121)));
      }
      // Stage the consumed half of T8 in T2.
      // NOTE(review): the zero-fill below writes T2[0..31], and the copy loop
      // right after it overwrites the same 32 elements, so these stores are
      // dead.
      __half T2[32];
      #pragma unroll
      for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
        int i3269;
        i3269 = 8 * i123;
        #pragma unroll
        for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
          T2[(i3269 + i125)] = 0.00000000000000000e+00;
        }
      }
      #pragma unroll
      for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
        int i3278;
        i3278 = 8 * i123;
        int i3302;
        i3302 = i3298 + i3278;
        #pragma unroll
        for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
          T2[(i3278 + i125)]
             = T8[(i3302 + i125)];
        }
      }
      #pragma unroll
      for(nvfuser_index_t i132 = 0; i132 < 4; ++i132) {
        Turing::ldMatrix (*reinterpret_cast<Array<__half,8,8>*>(&T9[(i3259 + (8 * i132))]),(i3507 + (1024 * i132)));
      }
      // Stage the consumed half of T9 in T3 (same zero-then-overwrite pattern
      // as T2 above).
      __half T3[32];
      #pragma unroll
      for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
        int i3545;
        i3545 = 8 * i134;
        #pragma unroll
        for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
          T3[(i3545 + i135)] = 0.00000000000000000e+00;
        }
      }
      #pragma unroll
      for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
        int i3554;
        i3554 = 8 * i134;
        int i3578;
        i3578 = i3298 + i3554;
        #pragma unroll
        for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
          T3[(i3554 + i135)]
             = T9[(i3578 + i135)];
        }
      }
      // 4 x 4 calls: the i142-th 8-half group of T2 is paired with the
      // i143-th 8-half group of T3, updating the T5 sub-tile at
      // 32 * i142 + 4 * i143.
      // NOTE(review): presumably a tensor-core multiply-accumulate -- confirm.
      #pragma unroll
      for(nvfuser_index_t i142 = 0; i142 < 4; ++i142) {
        int i3645;
        i3645 = 32 * i142;
        #pragma unroll
        for(nvfuser_index_t i143 = 0; i143 < 4; ++i143) {
          Ampere::M16N16K16TN<16>(
            reinterpret_cast<Array<float,8,8>*>(&T5[(i3645 + (4 * i143))]),
            &(reinterpret_cast<Array<__half,8,8>*>(&T2)[i142]),
            &(reinterpret_cast<Array<__half,8,8>*>(&T3)[i143]));
        }
      }
    }
    // Second multiply of this iteration: consume elements [32, 64) of T8/T9,
    // which were loaded inside the i140 loop above. These T2/T3 are new arrays
    // in the outer loop scope, again zero-filled and then fully overwritten.
    __half T2[32];
    #pragma unroll
    for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
      int i3654;
      i3654 = 8 * i123;
      #pragma unroll
      for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
        T2[(i3654 + i125)] = 0.00000000000000000e+00;
      }
    }
    #pragma unroll
    for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
      int i3662;
      i3662 = 8 * i123;
      int i3681;
      i3681 = 32 + i3662;
      #pragma unroll
      for(nvfuser_index_t i125 = 0; i125 < 8; ++i125) {
        T2[(i3662 + i125)]
           = T8[(i3681 + i125)];
      }
    }
    __half T3[32];
    #pragma unroll
    for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
      int i3689;
      i3689 = 8 * i134;
      #pragma unroll
      for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
        T3[(i3689 + i135)] = 0.00000000000000000e+00;
      }
    }
    #pragma unroll
    for(nvfuser_index_t i134 = 0; i134 < 4; ++i134) {
      int i3697;
      i3697 = 8 * i134;
      int i3716;
      i3716 = 32 + i3697;
      #pragma unroll
      for(nvfuser_index_t i135 = 0; i135 < 8; ++i135) {
        T3[(i3697 + i135)]
           = T9[(i3716 + i135)];
      }
    }
    #pragma unroll
    for(nvfuser_index_t i142 = 0; i142 < 4; ++i142) {
      int i3780;
      i3780 = 32 * i142;
      #pragma unroll
      for(nvfuser_index_t i143 = 0; i143 < 4; ++i143) {
        Ampere::M16N16K16TN<16>(
          reinterpret_cast<Array<float,8,8>*>(&T5[(i3780 + (4 * i143))]),
          &(reinterpret_cast<Array<__half,8,8>*>(&T2)[i142]),
          &(reinterpret_cast<Array<__half,8,8>*>(&T3)[i143]));
      }
    }
    // Same wait + block barrier as after the prologue, placed before the next
    // iteration reads its stage and issues copies into another one.
    Ampere::cpAsyncPartialBarrier<1>();
    __syncthreads();
  }
  // ---- Epilogue: write the accumulator to T4 ------------------------------
  // The five nested loops enumerate all 4 * 4 * 2 * 2 * 2 = 128 elements of
  // T5. i148/i151 advance the T4 index in steps scaled by T1.size[0];
  // i149/i150/i152 advance it by 16, 8 and 1 elements respectively.
  #pragma unroll
  for(nvfuser_index_t i148 = 0; i148 < 4; ++i148) {
    int i3846;
    i3846 = 32 * i148;
    int i4011;
    i4011 = i4010 + (i3994 * i148);
    int i6894;
    i6894 = -(16 * i148);
    #pragma unroll
    for(nvfuser_index_t i149 = 0; i149 < 4; ++i149) {
      int i3848;
      i3848 = i3846 + (4 * i149);
      int i3997;
      i3997 = 16 * i149;
      int i4012;
      i4012 = i4011 + i3997;
      int i6920;
      i6920 = -i3997;
      #pragma unroll
      for(nvfuser_index_t i150 = 0; i150 < 2; ++i150) {
        int i3850;
        i3850 = i3848 + (2 * i150);
        int i3999;
        i3999 = 8 * i150;
        int i4013;
        i4013 = i4012 + i3999;
        int i6921;
        i6921 = i6920 - i3999;
        #pragma unroll
        for(nvfuser_index_t i151 = 0; i151 < 2; ++i151) {
          int i3852;
          i3852 = i3850 + (16 * i151);
          int i4014;
          i4014 = i4013 + (i4001 * i151);
          // Equivalent to
          // threadIdx.z * 64 + i528 + i5335 + 16 * i148 + 8 * i151 < T0.size[0].
          bool b6916;
          b6916 = i6915 < (i6894 - (8 * i151));
          #pragma unroll
          for(nvfuser_index_t i152 = 0; i152 < 2; ++i152) {
            // Second operand is equivalent to
            // i3990 + i4009 + i3988 + 16 * i149 + 8 * i150 + i152 < T1.size[0].
            // Only elements that pass both checks are stored.
            if ((b6916 && (i6919 < (i6921 - i152)))) {
              T4[(i4014 + i152)]
                 = T5[(i3852 + i152)];
            }
          }
        }
      }
    }
  }
}
zasdfgbnm commented 1 year ago

cc @mmigdal-nv