apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Dlight] Use 16x32 spatial x reduction thread extents in GEMV scheduling #17082

Closed by csullivan 2 weeks ago

csullivan commented 4 weeks ago

Switch to 16x32 (spatial x reduction) thread extents regardless of workload size. This works around a lowering bug that I haven't tracked down yet.

Currently, when the spatial dimension is larger than the reduction dimension, the schedule uses a 4x64 thread layout. With blockDim.x = 64, the reduction dimension spans two warps. An illegal CUDA instruction is encountered in the second warp during the __shfl_down_sync for the remainder portion of the reduction (introduced by the rfactor, I believe). AFAICT the mask calculation used for this remainder shfl is incorrect and is causing the error. Specifically, it occurs on the first thread of the second warp along x.
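
The warp arithmetic behind the two layouts can be made concrete with a small Python model. The helper `layout_summary` is hypothetical (not part of dlight or TVM); it assumes 32-thread warps, blockDim.z == 1, and that threadIdx.y indexes the spatial dimension while threadIdx.x indexes the reduction dimension, as the generated code in this PR suggests.

```python
WARP_SIZE = 32

def layout_summary(spatial_threads, reduction_threads):
    """Summarize a (threadIdx.y, threadIdx.x) = (spatial, reduction) block layout.

    Hypothetical helper for illustration; assumes blockDim.z == 1 and that
    threads are linearized as threadIdx.y * blockDim.x + threadIdx.x.
    """
    threads = spatial_threads * reduction_threads
    # Number of warps one reduction row (fixed threadIdx.y) spans.
    warps_per_row = -(-reduction_threads // WARP_SIZE)  # ceiling division
    return {
        "threads_per_block": threads,
        "warps_per_row": warps_per_row,
        # More than one warp per row means per-warp partial sums must be
        # combined in a second stage (the red_buf_staging path).
        "cross_warp_stage": warps_per_row > 1,
    }

for name, (ty_len, tx_len) in {"4x64": (4, 64), "16x32": (16, 32)}.items():
    print(name, layout_summary(ty_len, tx_len))
# 4x64 {'threads_per_block': 256, 'warps_per_row': 2, 'cross_warp_stage': True}
# 16x32 {'threads_per_block': 512, 'warps_per_row': 1, 'cross_warp_stage': False}
```

Under these assumptions, 4x64 needs the extra cross-warp stage, while 16x32 keeps each reduction row inside a single warp, so the remainder shuffle presumably never gets emitted.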

This is the relevant CUDA code that triggers the error:

if (((int)threadIdx.x) < 2) {
  red_buf0[0] = red_buf_staging[((((int)threadIdx.y) * 2) + ((int)threadIdx.x))];
}
mask[0] = (__activemask() & ((uint)(3 << (((int)threadIdx.y) * 2)))); // <<< likely the problem
t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
red_buf0[0] = (red_buf0[0] + t0[0]);
if (((int)threadIdx.x) == 0) {
  ((volatile half*)red_result)[((int)threadIdx.y)] = red_buf0[0];
}
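
One way to test the suspicion about the mask is to model it on the host. The sketch below uses hypothetical helper names and assumes that __activemask() returns all 32 lanes and that threads are linearized as threadIdx.y * blockDim.x + threadIdx.x. It lists the threads that hold a staged partial (threadIdx.x < 2) but whose own lane bit is missing from the mask they pass to __shfl_down_sync. The CUDA Programming Guide requires each calling thread to have its own bit set in the mask; otherwise the result is undefined.

```python
WARP_SIZE = 32
FULL_MASK = 0xFFFFFFFF

def warp_and_lane(tx, ty, block_x):
    # CUDA linearizes threads as threadIdx.y * blockDim.x + threadIdx.x
    # (blockDim.z assumed to be 1).
    linear = ty * block_x + tx
    return linear // WARP_SIZE, linear % WARP_SIZE

def remainder_mask(ty, active=FULL_MASK):
    # Mirrors: __activemask() & ((uint)(3 << (threadIdx.y * 2)))
    return active & ((3 << (ty * 2)) & FULL_MASK)

def check_layout(block_x, block_y):
    """List threads that hold a staged partial (threadIdx.x < 2) but whose
    own lane bit is missing from the mask passed to __shfl_down_sync."""
    missing = []
    for ty in range(block_y):
        for tx in range(2):  # the `if (threadIdx.x < 2)` branch
            warp, lane = warp_and_lane(tx, ty, block_x)
            mask = remainder_mask(ty)
            if not (mask >> lane) & 1:
                missing.append((ty, tx, warp, lane, mask))
    return missing

for ty, tx, warp, lane, mask in check_layout(block_x=64, block_y=4):
    print(f"ty={ty} tx={tx}: warp {warp}, lane {lane} not in mask {mask:#010x}")
```

Under these assumptions only the threadIdx.y == 0 row lines up: for every other row, the two threads holding staged partials sit at lanes 0 and 1 of warp 2 * threadIdx.y, while the mask names lanes 2 * threadIdx.y and 2 * threadIdx.y + 1. This only models the mask arithmetic; it does not predict which thread the hardware reports the fault on.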

The corresponding SASS, showing where the illegal instruction occurs:

   0x00007d9e97b92490 <+1936>:  WARPSYNC.ALL
   0x00007d9e97b924a0 <+1952>:  BAR.SYNC.DEFER_BLOCKING 0x0
   0x00007d9e97b924b0 <+1968>:  @!P1 VIADD R13, R5, 0x8
   0x00007d9e97b924c0 <+1984>:  @!P1 LEA R7, R17, R14, 0x1
   0x00007d9e97b924d0 <+2000>:  @!P1 PRMT R6, R2, 0x654, R13
   0x00007d9e97b924e0 <+2016>:  @!P1 LEA R7, R7, R6, 0x1
   0x00007d9e97b924f0 <+2032>:  @!P1 LDS.U16 R16, [R7]
   0x00007d9e97b92500 <+2048>:  IMAD.MOV.U32 R6, RZ, RZ, 0x3
   0x00007d9e97b92510 <+2064>:  SHF.L.U32 R17, R17, 0x1, RZ
   0x00007d9e97b92520 <+2080>:  VOTEU.ANY UR4, UPT, PT
   0x00007d9e97b92530 <+2096>:  SHF.L.U32 R3, R6, R17, RZ
   0x00007d9e97b92540 <+2112>:  LOP3.LUT R3, R3, UR4, RZ, 0xc0, !PT
   0x00007d9e97b92550 <+2128>:  ISETP.NE.AND P0, PT, R14, RZ, PT
   0x00007d9e97b92560 <+2144>:  PRMT R2, R2, 0x654, R5
   0x00007d9e97b92570 <+2160>:  PRMT R4, R16, 0x5410, R16
*> 0x00007d9e97b92580 <+2176>:  WARPSYNC R3
=> 0x00007d9e97b92590 <+2192>:  SHFL.DOWN PT, R3, R4, 0x1, 0x1f
   0x00007d9e97b925a0 <+2208>:  IMAD.IADD R17, R17, 0x1, R2
   0x00007d9e97b925b0 <+2224>:  HADD2 R16, R16.H0_H0, R3.H0_H0
   0x00007d9e97b925c0 <+2240>:  @!P0 STS.U16 [R17], R16
   0x00007d9e97b925d0 <+2256>:  WARPSYNC.ALL
   0x00007d9e97b925e0 <+2272>:  BAR.SYNC.DEFER_BLOCKING 0x0
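
Reading the SASS against the CUDA, the instructions at <+2048> through <+2112> appear to build the mask as R3 = (3 << (R17 << 1)) & UR4, which the WARPSYNC R3 at <+2176> then consumes. Here VOTEU.ANY seems to produce the active mask in UR4, and LOP3.LUT with immediate 0xc0 is a bitwise AND of its first two operands. Treating R17 as threadIdx.y and UR4 as __activemask() is an assumed reading of the listing (it is consistent with the staging index computed at <+1984>), not something stated above. The sketch below evaluates a LOP3 lookup table bit by bit to confirm that 0xc0 is a & b and replays the mask computation.

```python
MASK32 = 0xFFFFFFFF

def lop3(a, b, c, lut):
    """Evaluate a LOP3-style 3-input lookup table bit by bit on 32-bit values.

    The LUT index for each bit is (a_bit << 2) | (b_bit << 1) | c_bit, i.e. the
    convention where lut == f(0xF0, 0xCC, 0xAA).
    """
    out = 0
    for bit in range(32):
        idx = (((a >> bit) & 1) << 2) | (((b >> bit) & 1) << 1) | ((c >> bit) & 1)
        out |= ((lut >> idx) & 1) << bit
    return out

def sass_mask(ty, active=MASK32):
    # Assumed register roles: R17 = threadIdx.y, UR4 = active mask, RZ = 0.
    r6 = 0x3                          # IMAD.MOV.U32 R6, RZ, RZ, 0x3
    r17 = (ty << 1) & MASK32          # SHF.L.U32 R17, R17, 0x1, RZ
    r3 = (r6 << r17) & MASK32         # SHF.L.U32 R3, R6, R17, RZ (plain left shift)
    return lop3(r3, active, 0, 0xC0)  # LOP3.LUT R3, R3, UR4, RZ, 0xc0

print(hex(lop3(0xF0, 0xCC, 0xAA, 0xC0)))  # 0xc0, the LUT for a & b
for ty in range(4):
    print(ty, hex(sass_mask(ty)))
```

If this reading is right, the SASS faithfully implements the CUDA expression, so the problem would lie in the mask formula TVM emits rather than in the compiler's translation of it.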

Changing the thread extents to 16x32 (one warp along the reduction dimension) works around the issue. It also improves performance for the tested shapes by ~10%.

Benchmarked with shape (8, 2048, 4096), which avoids the error:

# 4x64
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)             Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  --------------------------
     81.5           612214        101    6061.5    6048.0      5920      7872        188.5  moe_dequantize_gemv_kernel

# 16x32
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)             Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  --------------------------
     79.9           555901        101    5504.0    5472.0      5439      6880        142.7  moe_dequantize_gemv_kernel
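
The ~10% figure can be checked from the tables: the ratio of average kernel times gives the speedup, and one minus its inverse gives the fraction of kernel time saved. A minimal Python check using the Avg and Total Time columns:

```python
# Values copied from the nsys tables above (nanoseconds).
avg_4x64, avg_16x32 = 6061.5, 5504.0
total_4x64, total_16x32 = 612214, 555901

speedup = avg_4x64 / avg_16x32             # how many times faster 16x32 runs
time_reduction = 1 - avg_16x32 / avg_4x64  # fraction of kernel time saved
total_ratio = total_4x64 / total_16x32     # same comparison on summed time

print(f"speedup {speedup:.3f}x, time saved {time_reduction:.1%}, total ratio {total_ratio:.3f}")
# speedup 1.101x, time saved 9.2%, total ratio 1.101
```

So "~10%" matches the speedup; expressed as kernel time saved, it is closer to 9%.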

tqchen commented 4 weeks ago

@csullivan please help fix the unit test; in this case it seems fine to directly change the expected TIR.