NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Layer norm with fused ops performance regression #383

Closed: naoyam closed this issue 1 year ago

naoyam commented 1 year ago

This is on a pytorch-A100 node:

 ./bin/nvfuser_bench --benchmark_filter="NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/" --benchmark_min_time=0.01
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                             Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/768/manual_time         27.9 us          105 us          496 bytes_per_second=841.259G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 1 / vectorize / factor 8/Launch_Parameters[block(1/1/96)/grid(1/1/8192)/384]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1024/manual_time        34.6 us          114 us          399 bytes_per_second=903.322G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 1 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1536/manual_time        47.9 us          128 us          290 bytes_per_second=979.304G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/96)/grid(1/1/8192)/384]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1280/manual_time        47.3 us          126 us          293 bytes_per_second=825.381G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/96)/grid(1/1/8192)/384]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1600/manual_time        60.0 us          140 us          232 bytes_per_second=814.405G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/2048/manual_time        60.6 us          143 us          230 bytes_per_second=1030.95G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/2560/manual_time        90.0 us          168 us          155 bytes_per_second=868.143G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 3 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/4096/manual_time         130 us          212 us          107 bytes_per_second=960.641G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 4 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/5140/manual_time         432 us          506 us           32 bytes_per_second=362.976G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain: unroll / factor 2 // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 5 / vectorize / factor 4/Launch_Parameters[block(1/1/288)/grid(1/1/4096)/1152]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/12288/manual_time        430 us          502 us           32 bytes_per_second=871.69G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 6 / vectorize / factor 8/Launch_Parameters[block(1/1/256)/grid(1/1/8192)/1024]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/16384/manual_time       1740 us         1817 us            8 bytes_per_second=287.485G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 8 / vectorize / factor 8/Launch_Parameters[block(1/1/256)/grid(1/1/8192)/1024]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/20480/manual_time       1045 us         1118 us           13 bytes_per_second=598.259G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 10 / vectorize / factor 8/Launch_Parameters[block(1/1/256)/grid(1/1/8192)/1024]

The last four cases seem slow, in particular LayerNormFusedOp_fp16/8192/16384 is only 287 GB/s. Is this expected? Can we improve the performance by adjusting heuristics? Here's ptxas info for the case:

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel7ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel7ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     704 bytes stack frame, 704 bytes spill stores, 746 bytes spill loads
ptxas info    : Used 128 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]
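To put the 287 GB/s figure in context, the `bytes_per_second` counter appears to count one read of the [batch, hidden] fp16 input, one write of the [batch, hidden] fp16 output, and one read of each of the four [hidden] fp16 vectors, reported in binary units (G = 2^30). This is inferred only from the numbers in the table, not from the benchmark source, so treat the helper below as an assumption-based sketch.

```cpp
#include <cassert>
#include <cstdint>

// ASSUMPTION: bytes moved = 2D input read + 2D output write + four 1D vectors,
// all fp16. Inferred from the reported numbers, not from the benchmark code.
std::int64_t layerNormFusedOpBytes(std::int64_t batch, std::int64_t hidden) {
  constexpr std::int64_t kHalfBytes = 2;
  const std::int64_t matrix_bytes = batch * hidden * kHalfBytes;  // one [batch, hidden] tensor
  const std::int64_t vector_bytes = hidden * kHalfBytes;          // one [hidden] vector
  return 2 * matrix_bytes + 4 * vector_bytes;
}

// Binary units (1 G = 2^30 bytes), as bytes_per_second appears to use.
double gibPerSecond(std::int64_t bytes, double time_us) {
  assert(time_us > 0.0);
  return static_cast<double>(bytes) / (time_us * 1e-6) / 1073741824.0;
}

// Decimal units (1 GB = 10^9 bytes), as dump_eff_bandwidth appears to use.
double gbPerSecond(std::int64_t bytes, double time_ms) {
  assert(time_ms > 0.0);
  return static_cast<double>(bytes) / (time_ms * 1e-3) / 1e9;
}
```

For 8192 x 16384 this gives 537,001,984 bytes and roughly 287.4 GiB/s at 1740 us, in line with the 287.485G/s reported. The same byte count for 8192 x 5140 (168,468,640 bytes) also seems to reproduce the `dump_eff_bandwidth` figures posted later in the thread when read in decimal units, e.g. about 495.5 GB/s at 0.339968 ms.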
liqiangxl commented 1 year ago

The low performance on the current main branch is expected, as the redundant casts are not removed there. I checked the perf using jie's branch #355. All cases look good except the one with a hidden size of 5140.

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                             Time             CPU   Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/768/manual_time         26.7 us          108 us          519 bytes_per_second=876.955G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 1 / vectorize / factor 8/Launch_Parameters[block(1/1/96)/grid(1/1/8192)/384]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1024/manual_time        32.8 us          117 us          423 bytes_per_second=952.711G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 1 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1536/manual_time        45.5 us          130 us          304 bytes_per_second=1030.3G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/96)/grid(1/1/8192)/384]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1280/manual_time        44.3 us          127 us          313 bytes_per_second=881.201G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/96)/grid(1/1/8192)/384]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/1600/manual_time        55.4 us          142 us          251 bytes_per_second=882.336G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/2048/manual_time        58.0 us          146 us          241 bytes_per_second=1078.75G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 2 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/2560/manual_time        77.5 us          156 us          180 bytes_per_second=1008.31G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 3 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/4096/manual_time         105 us          194 us          133 bytes_per_second=1.1619T/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 4 / vectorize / factor 8/Launch_Parameters[block(1/1/128)/grid(1/1/8192)/512]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/5140/manual_time         304 us          385 us           46 bytes_per_second=515.708G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain: unroll / factor 2 // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 5 / vectorize / factor 4/Launch_Parameters[block(1/1/288)/grid(1/1/4096)/1152]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/12288/manual_time        328 us          410 us           42 bytes_per_second=1.11537T/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 6 / vectorize / factor 8/Launch_Parameters[block(1/1/256)/grid(1/1/8192)/1024]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/16384/manual_time        514 us          602 us           27 bytes_per_second=973.484G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 8 / vectorize / factor 8/Launch_Parameters[block(1/1/256)/grid(1/1/8192)/1024]
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/20480/manual_time        672 us          747 us           21 bytes_per_second=929.899G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain:  // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 10 / vectorize / factor 8/Launch_Parameters[block(1/1/256)/grid(1/1/8192)/1024]

The perf of 5140 is low because: (1) it is vectorized by 4 (5140 is divisible by 4 but not by 8) while the other cases are vectorized by 8. The kernel uses max_unroll = 16 / max_input_dtype_size = 16 / 2 = 8, so the remaining factor of 2 is put into iter_unroll_factor. This iter_unroll_factor increases the persistent buffer size, and the extra code to process the unroll leads to high register usage, e.g. it needs 155 registers (if I run with PYTORCH_NVFUSER_MAX_REG_COUNT=255) while the estimated register count is 80. (2) The reduction size is not a multiple of the warp size, so bdimx is padded (to 288 in the launch parameters above).
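The split arithmetic described above can be worked through in a few lines. This is a hypothetical reconstruction of the numbers only, not nvFuser's heuristic code; `sketchSplit` and `SplitSketch` are invented names, and the persistent batch is taken as an input from the launch logs rather than derived.

```cpp
#include <cassert>
#include <cstdint>

// HYPOTHETICAL sketch of the arithmetic in the comment above; not nvFuser code.
struct SplitSketch {
  std::int64_t vectorize;       // elements per vectorized load
  std::int64_t iter_unroll;     // leftover unroll pushed to the iteration domain
  std::int64_t threads_needed;  // threads needed along the reduction domain
  std::int64_t bdimx;           // threads after padding to a warp multiple
};

SplitSketch sketchSplit(std::int64_t hidden, std::int64_t dtype_bytes,
                        std::int64_t persistent_batch) {
  assert(hidden > 0 && dtype_bytes > 0 && persistent_batch > 0);
  const std::int64_t max_unroll = 16 / dtype_bytes;  // 16-byte loads
  // Largest power-of-two vector width <= max_unroll that divides hidden.
  std::int64_t vec = max_unroll;
  while (vec > 1 && hidden % vec != 0) vec /= 2;
  const std::int64_t iter_unroll = max_unroll / vec;
  const std::int64_t chunks = hidden / vec;  // vectorized chunks per row
  const std::int64_t threads =
      (chunks + persistent_batch - 1) / persistent_batch;  // ceil division
  const std::int64_t padded = (threads + 31) / 32 * 32;    // pad to warp size
  return {vec, iter_unroll, threads, padded};
}
```

For hidden = 5140 with persistent batch 5, this yields vectorize 4, iter_unroll 2, 257 threads padded to 288, matching the launch parameters; only 257 of the 288 threads have work, so about 11% of each block is padding. For 4096 (persistent batch 4) it yields vectorize 8, no iteration unroll, and 128 threads, as in the table.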

If I set iter_unroll_factor = 1, the bandwidth increases from 515.708G/s to 806.959G/s. So we should probably stop using iter_unroll_factor and extend the use of multiple_reds_per_blk to handle cases with a large iteration domain but a small reduction domain. Working on a PR.

naoyam commented 1 year ago

Didn't you also find some issue with selection of persistent buffers? Can't understand why the castOp itself could make such a large difference. Have you already filed an issue?

Do you have any idea why the register usage is so much higher than the estimate? Using a different scheduling heuristic may be a way to go, but before making a decision, we should at least try to understand why.

liqiangxl commented 1 year ago

> Didn't you also find some issue with selection of persistent buffers? Can't understand why the castOp itself could make such a large difference. Have you already filed an issue?

See https://github.com/NVIDIA/Fuser/issues/343

> Do you have any idea why the register usage is so much higher than the estimate? Using a different scheduling heuristic may be a way to go, but before making a decision, we should at least try to understand why.

Because iter_unroll_factor = 2 means there is an additional unrolled loop on top of vectorization and the persistent batch, the registers needed for the buffers and the loop overhead are doubled, if the loop is successfully unrolled. This explains why 155 registers are used while the estimate is 80.

naoyam commented 1 year ago

Doesn't that mean we should also fix how to estimate register usage?

naoyam commented 1 year ago

I looked at the generated code of the 5140 case. A couple of things I noticed:

Here's the initial read of the input tensor:

  if ((((((i185 * 16) + i546) + 3) < T0.size[0]) && ((1 + i3549) < T1.size[0]))) {
    Array<__half, 40, 4> T42;
    #pragma unroll
    for(nvfuser_index_t i396 = 0; i396 < 2; ++i396) {
      int i487;
      i487 = 20 * i396;
      #pragma unroll
      for(nvfuser_index_t i397 = 0; i397 < 5; ++i397) {
        T42.set(__half(0));
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    #pragma unroll
    for(nvfuser_index_t i396 = 0; i396 < 2; ++i396) {
      int i551;
      i551 = i549 + (T0.size[0] * i396);
      int i577;
      i577 = 20 * i396;
      #pragma unroll
      for(nvfuser_index_t i397 = 0; i397 < 5; ++i397) {
        loadGlobalToLocal<__half, 4, false>(&T42[(i577 + (4 * i397))],  &T1[(i551 + (i552 * (i397 + nvfuser_zero)))]);
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    // Alias Allocation - register
    auto& T46 = T42;
    #pragma unroll
    for(nvfuser_index_t i435 = 0; i435 < 2; ++i435) {
      int i742;

I wonder why T42 is not inlined with the i435 loop. It seems the kernel reads the whole input tensor once, does the normalization, and then writes the whole output. I suppose T42 is not a persistent buffer, but it's merely created due to the computeAt position. It would be interesting to try inlining it with the computation loop.

I don't think this is a scheduling issue, but it seems there are two identical sum reductions of T15, likely because the input fusion has both sum and variance. Can this benchmark use Welford instead?
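For reference, Welford's algorithm computes the mean and the variance in a single pass, so one reduction could replace the separate `sum` and `variance` reductions. Below is a minimal host-side C++ sketch of the textbook algorithm (not nvFuser's Welford op), including the merge step used when partial results from different threads are combined. It returns the population (biased) variance, which is what `variance(tv13, {1}, false, false)` in the test case later in this thread appears to request.

```cpp
#include <cassert>
#include <cstdint>

// Running Welford state: sample count, mean, and M2 (sum of squared deviations).
struct WelfordState {
  std::int64_t n = 0;
  double mean = 0.0;
  double m2 = 0.0;

  // Add one sample.
  void push(double x) {
    ++n;
    const double delta = x - mean;
    mean += delta / static_cast<double>(n);
    m2 += delta * (x - mean);  // uses the updated mean
  }

  // Combine with another partial state (e.g. from a different thread).
  void merge(const WelfordState& other) {
    if (other.n == 0) return;
    if (n == 0) { *this = other; return; }
    const std::int64_t total = n + other.n;
    const double delta = other.mean - mean;
    mean += delta * static_cast<double>(other.n) / static_cast<double>(total);
    m2 += other.m2 + delta * delta * static_cast<double>(n) *
                         static_cast<double>(other.n) / static_cast<double>(total);
    n = total;
  }

  // Population (biased) variance.
  double variance() const {
    assert(n > 0);
    return m2 / static_cast<double>(n);
  }
};
```

Pushing 1, 2, 3, 4 gives mean 2.5 and variance 1.25; splitting the samples into two states and merging them gives the same result.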

Assuming T42 is not a persistent buffer, the other relatively large buffers are:

 __half T14[20];
 float T15[20];

These are inside the unswitch loop of size 2, so the compiler may double their size, which should still come to only 60 registers, meaning the additional "overhead" is nearly 100 registers, which seems quite high.
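The 60-register figure follows from simple arithmetic, assuming each 32-bit register holds one `float` or two packed `__half` values and that the unswitch loop duplicates both buffers. This is a back-of-the-envelope model (the helper names are hypothetical), not how nvFuser or ptxas actually allocates registers; the 155 registers is the measurement quoted earlier in the thread.

```cpp
#include <cassert>

// Number of 32-bit registers needed for `count` elements of `elem_bytes`
// bytes each, ASSUMING elements are packed tightly into registers.
int registersFor(int count, int elem_bytes) {
  assert(count >= 0 && elem_bytes > 0);
  const int bytes = count * elem_bytes;
  return (bytes + 3) / 4;  // round up to whole 4-byte registers
}

// Registers for __half T14[20] plus float T15[20], duplicated `copies` times
// (e.g. by an unswitch loop of size 2).
int bufferRegisters(int copies) {
  assert(copies >= 1);
  return copies * (registersFor(20, 2) + registersFor(20, 4));
}
```

Here `bufferRegisters(2)` is 2 x (10 + 20) = 60, leaving 155 - 60 = 95 registers unexplained by these two buffers.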

liqiangxl commented 1 year ago

> Doesn't that mean we should also fix how to estimate register usage?

I think we don't need to bother modifying the current register usage estimation to account for iter_unroll_factor. Instead, we can abandon iter_unroll_factor for persistent kernels. It uses too many registers, which leads to low occupancy and slow kernels. It is also not efficient even for small cases that are not bound by register usage, see #386.

As a solution, we can use multiple_reds_per_blk, as it uses fewer registers and leads to faster kernels.

naoyam commented 1 year ago

Have you actually fixed the estimate and compared the results?

liqiangxl commented 1 year ago

> Have you actually fixed the estimate and compared the results?

On main branch with iter_unroll_factor = 2 generated by heuristics, run with PYTORCH_NVFUSER_MAX_REG_COUNT=255 to remove register spills:

ptxas info    : 307 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 157 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.715776 ms, achieved: 235.365 GB/s

It is slower than the original version with register spills:

ptxas info    : 307 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     256 bytes stack frame, 248 bytes spill stores, 296 bytes spill loads
ptxas info    : Used 72 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.339968 ms, achieved: 495.543 GB/s
naoyam commented 1 year ago

Hmm, I don't get the same result as yours. Which version of the toolkit are you using?

My branch is:

commit ad96bd57f3733800fc1998935caa4adb7180d6bf (HEAD -> main, origin/main, origin/HEAD)
Author: Gao, Xiang <qasdfgtyuiop@gmail.com>
Date:   Tue May 23 14:56:57 2023 -0700

    Allocation domain `SchedulerRuntimeInfo` fix (#392)

I'm using CUDA 11.8, and this is what I get for the 5140 case by default:


ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     208 bytes stack frame, 810 bytes spill stores, 924 bytes spill loads
ptxas info    : Used 72 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

----------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations UserCounters...
----------------------------------------------------------------------------------------------------------------------------------------------------------------------
NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/5140/manual_time        433 us          507 us         1621 bytes_per_second=362.589G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain: unroll / factor 2 // Inner Reduction Domain: cross block reduction / pad to warp / persistent batch - 5 / vectorize / factor 4/Launch_Parameters[block(1/1/288)/grid(1/1/4096)/1152]

Using the max reg environment variable results in an error:

PYTORCH_NVFUSER_MAX_REG_COUNT=255 PYTORCH_NVFUSER_DUMP=ptxas_verbose   ./bin/nvfuser_bench --benchmark_filter="NvFuserScheduler_LayerNormFusedOp_fp16___GRAPH/NvFuserScheduler_LayerNormFusedOp_fp16/8192/5140/*"
The number of inputs is very large. NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16 will be repeated at least 270 times.
The number of inputs is very large. NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_fp16 will be repeated at least 120 times.
The number of inputs is very large. NvFuserScheduler_TIMM_BatchNorm_nhwc_BWD_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_BWD_fp16 will be repeated at least 270 times.
The number of inputs is very large. NvFuserScheduler_TIMM_BatchNorm_nhwc_BWD_fp16___GRAPH/NvFuserScheduler_TIMM_BatchNorm_nhwc_BWD_fp16 will be repeated at least 120 times.
The number of inputs is very large. NvFuserScheduler_TIMM_LayerNorm_fp16___GRAPH/NvFuserScheduler_TIMM_LayerNorm_fp16 will be repeated at least 270 times.
The number of inputs is very large. NvFuserScheduler_TIMM_LayerNorm_fp16___GRAPH/NvFuserScheduler_TIMM_LayerNorm_fp16 will be repeated at least 120 times.
The number of inputs is very large. Baseline_TIMM_LayerNorm_fp16 will be repeated at least 270 times.
The number of inputs is very large. Baseline_TIMM_LayerNorm_fp16 will be repeated at least 120 times.
2023-05-23T17:28:50-07:00
Running ./bin/nvfuser_bench
Run on (64 X 3500 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x32)
  L1 Instruction 32 KiB (x32)
  L2 Unified 512 KiB (x32)
  L3 Unified 16384 KiB (x8)
Load Average: 0.08, 0.12, 0.42
***WARNING*** CPU scaling is enabled, the benchmark real time measurements may be noisy and will incur extra overhead.
__tmp_kernel1.cu(9571): warning #550-D: variable "i487" was set but never used

__tmp_kernel1.cu(9818): warning #550-D: variable "i1827" was set but never used

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 182 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

zsh: abort (core dumped)  PYTORCH_NVFUSER_MAX_REG_COUNT=255 PYTORCH_NVFUSER_DUMP=ptxas_verbose
liqiangxl commented 1 year ago

Initially (first post, 4 days ago), I used jie's branch https://github.com/NVIDIA/Fuser/pull/355. In yesterday's post, I retested on the main branch using the test case FusionLayerNormFusedOpsRedundantCast_CUDA, which requires modifying the batch size and hidden size and manually removing the redundant casts.

naoyam commented 1 year ago

Please post what you actually measured.

liqiangxl commented 1 year ago

What I did on nvdl-a112-d001, using our nightly build pulled on 05/22/23:

(1) Code base:

commit ad96bd57f3733800fc1998935caa4adb7180d6bf (HEAD -> main, origin/main, origin/HEAD)
Author: Gao, Xiang <qasdfgtyuiop@gmail.com>
Date:   Tue May 23 14:56:57 2023 -0700

    Allocation domain `SchedulerRuntimeInfo` fix (#392)

(2) Test case:

TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  const float kEps = 1e-5;
  const int batch_size = 8192;
  const int hidden_size = 5140;
  {
    DataType dtype = DataType::Half;
    auto tv0 = makeContigTensor(1, dtype);
    auto tv1 = makeContigTensor(2, dtype);
    auto tv2 = makeContigTensor(1, dtype);
    auto tv3 = makeContigTensor(1, dtype);
    auto tv4 = makeContigTensor(1, dtype);

    fusion->addInput(tv0);
    fusion->addInput(tv1);
    fusion->addInput(tv2);
    fusion->addInput(tv3);
    fusion->addInput(tv4);
    auto tv5 = broadcast(tv0, {true, false});
    auto tv6 = castOp(DataType::Float, tv1);
    auto tv7 = castOp(DataType::Float, tv5);
    auto tv8 = add(tv6, tv7);
    auto tv9 = castOp(DataType::Half, tv8);
    auto tv10 = broadcast(tv2, {true, false});
    auto tv11 = castOp(DataType::Float, tv9);
    auto tv12 = castOp(DataType::Float, tv10);
    auto tv13 = add(tv11, tv12);
    // auto tv14 = castOp(DataType::Half, tv13);
    // auto tv15 = castOp(DataType::Float, tv14);
    auto tv16 = variance(tv13, {1}, false, false);
    auto tv17 = broadcast(tv16, {false, true});
    auto tv18 = sum(tv13, {1}, false);
    auto tv19 = broadcast(tv18, {false, true});

    // Pass the dtype positionally; `dtype = DataType::Double` would reassign
    // the local `dtype` variable declared above.
    nvfuser::Val* num_features =
        IrBuilder::create<Double>(1, DataType::Double);
    num_features = mul(num_features, tv0->getLeafDomain()[0]->extent());
    auto s20 = num_features;

    auto s21 = reciprocal(s20);
    auto tv22 = mul(tv19, s21);
    auto s23 = IrBuilder::create<Double>(kEps, DataType::Double);
    auto tv24 = add(tv17, s23);
    auto tv25 = rsqrt(tv24);
    auto tv26 = broadcast(tv22, {false, false});
    // auto tv27 = castOp(DataType::Float, tv14);
    auto tv28 = sub(tv13, tv26);
    auto tv29 = broadcast(tv25, {false, false});
    auto tv30 = mul(tv28, tv29);
    auto tv31 = broadcast(tv4, {true, false});
    auto tv32 = castOp(DataType::Float, tv31);
    auto tv33 = mul(tv30, tv32);
    auto tv34 = broadcast(tv3, {true, false});
    auto tv35 = castOp(DataType::Float, tv34);
    auto tv36 = add(tv33, tv35);
    auto tv37 = castOp(DataType::Half, tv36);
    fusion->addOutput(tv37);
  }

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  std::vector<c10::IValue> inputs;
  std::vector<at::Tensor> outputs;

  {
    auto t0 = at::randn({hidden_size}, options);
    auto t1 = at::randn({batch_size, hidden_size}, options);
    auto t2 = at::randn({hidden_size}, options);
    auto t3 = at::randn({hidden_size}, options);
    auto t4 = at::randn({hidden_size}, options);
    inputs.emplace_back(t0);
    inputs.emplace_back(t1);
    inputs.emplace_back(t2);
    inputs.emplace_back(t3);
    inputs.emplace_back(t4);
    auto t5 = t0.unsqueeze(0).expand({batch_size, hidden_size});
    auto t6 = t1.to(at::kFloat);
    auto t7 = t5.to(at::kFloat);
    auto t8 = at::add(t6, t7);
    auto t9 = t8.to(at::kHalf);
    auto t10 = t2.unsqueeze(0).expand({batch_size, hidden_size});
    auto t11 = t9.to(at::kFloat);
    auto t12 = t10.to(at::kFloat);
    auto t13 = at::add(t11, t12);
    auto t14 = t13.to(at::kHalf);
    auto aten_outputs = at::native_layer_norm(t14, {hidden_size}, t4, t3, kEps);
    auto t33 = std::get<0>(aten_outputs);
    outputs.emplace_back(t33);
  }

  FusionExecutorCache fec(std::move(fusion_ptr));
  auto cg_outputs = fec.runFusionWithInputs(inputs);
  testValidate(fusion, cg_outputs, inputs, outputs, __LINE__, __FILE__);
}

(3) Result with PYTORCH_NVFUSER_DUMP=ptxas_verbose,dump_eff_bandwidth,launch_param ./nvfuser_tests --gtest_filter=NVFuserTest.FusionLayerNormFusedOpsRedundantCast_CUDA

ptxas info    : 307 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     256 bytes stack frame, 248 bytes spill stores, 296 bytes spill loads
ptxas info    : Used 72 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.339968 ms, achieved: 495.543 GB/s

(4) Result with PYTORCH_NVFUSER_MAX_REG_COUNT=255 PYTORCH_NVFUSER_DUMP=ptxas_verbose,dump_eff_bandwidth,launch_param ./nvfuser_tests --gtest_filter=NVFuserTest.FusionLayerNormFusedOpsRedundantCast_CUDA

ptxas info    : 307 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 157 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.715776 ms, achieved: 235.365 GB/s
naoyam commented 1 year ago

I was able to reproduce your results.

Any idea why the 255 case is so much slower than the default case?

naoyam commented 1 year ago

I'm looking into the fusion. There's only one persistent buffer:

   float T13[20];

But still the register usage looks terrible:

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     248 bytes stack frame, 236 bytes spill stores, 280 bytes spill loads
ptxas info    : Used 72 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.321536 ms, achieved: 523.95 GB/s

Looks like it requires at least 130 registers:

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 130 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.726016 ms, achieved: 232.045 GB/s

With iter_unroll_factor disabled, the register usage is decreased significantly:

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 45 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 8192, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.200704 ms, achieved: 839.389 GB/s

The last case only uses 45 registers, which seems pretty good considering there's a buffer of float[20].

With iter_unroll_factor==2, the register count jumps to 130, which seems too much.

liqiangxl commented 1 year ago

> I was able to reproduce your results.

> Any idea why the 255 case is so much slower than the default case?

Maybe it is due to low occupancy. The 255 case can only have 1 active block per SM, while the original case with 72 registers can have 3 blocks per SM.
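The block counts above can be reproduced with a simplified register-limited occupancy model. The constants are assumed sm_80 (A100) values (64K registers per SM, registers allocated per warp in units of 256, at most 64 warps and 32 blocks per SM); the model ignores shared memory and other limits, so treat it as a sketch and prefer the CUDA occupancy calculator or `cudaOccupancyMaxActiveBlocksPerMultiprocessor` for authoritative numbers.

```cpp
#include <cassert>

// ASSUMED sm_80 limits for a simplified, register-only occupancy model.
constexpr int kRegistersPerSM = 65536;
constexpr int kRegisterAllocUnit = 256;  // per-warp allocation granularity
constexpr int kMaxWarpsPerSM = 64;
constexpr int kMaxBlocksPerSM = 32;
constexpr int kWarpSize = 32;

// Blocks per SM when registers are the only limiting resource.
int blocksPerSMByRegisters(int regs_per_thread, int threads_per_block) {
  assert(regs_per_thread > 0 && threads_per_block > 0);
  const int warps_per_block = (threads_per_block + kWarpSize - 1) / kWarpSize;
  // Registers per warp, rounded up to the allocation unit.
  const int regs_per_warp =
      (regs_per_thread * kWarpSize + kRegisterAllocUnit - 1) /
      kRegisterAllocUnit * kRegisterAllocUnit;
  int warps = kRegistersPerSM / regs_per_warp;
  if (warps > kMaxWarpsPerSM) warps = kMaxWarpsPerSM;
  int blocks = warps / warps_per_block;
  if (blocks > kMaxBlocksPerSM) blocks = kMaxBlocksPerSM;
  return blocks;
}
```

With 288-thread blocks (9 warps), this model gives 3 blocks per SM at 72 registers, 1 block at 157 or 130 registers, 2 blocks at 88 registers, and 4 blocks at 45 registers.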

naoyam commented 1 year ago

One reason seems to be this unswitch:

https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/reduction_utils.cpp#L256

Unswitching the iteration domain means unswitching the whole reduction domain, which doesn't work well unless the extent is completely evenly divisible, which is not the case with this problem size. So, aside from the register usage, it is understandable that the efficiency gets reduced.
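For readers unfamiliar with unswitching: the idea is to hoist one predicate that covers every access in the loop nest, then emit a fast unpredicated copy and a slow predicated copy of the body. The sketch below is a generic host-side C++ illustration of that technique, not nvFuser's generated code; the function name and the column-slice parameters are hypothetical. When the hoisted predicate must cover the whole reduction slice of both rows, any thread whose slice crosses the tensor boundary falls back to the predicated copy, and the kernel has to carry both copies.

```cpp
#include <cassert>
#include <vector>

// HYPOTHETICAL illustration of loop unswitching over a pair of rows.
// Sums columns [col_begin, col_end) of rows row0 and row0 + 1 into sums[0..1].
// Returns true if the fast (unpredicated) path was taken.
bool sumRowPairUnswitched(const std::vector<float>& in, int nrows, int ncols,
                          int row0, int col_begin, int col_end, float sums[2]) {
  assert(static_cast<int>(in.size()) == nrows * ncols);
  assert(row0 >= 0 && col_begin >= 0);
  sums[0] = sums[1] = 0.0f;
  // Hoisted predicate: every access of both rows must be in bounds.
  const bool all_in_bounds = row0 + 1 < nrows && col_end <= ncols;
  if (all_in_bounds) {
    // Fast path: no per-element checks.
    for (int r = 0; r < 2; ++r)
      for (int c = col_begin; c < col_end; ++c)
        sums[r] += in[(row0 + r) * ncols + c];
    return true;
  }
  // Slow path: each access is predicated individually.
  for (int r = 0; r < 2; ++r)
    for (int c = col_begin; c < col_end; ++c)
      if (row0 + r < nrows && c < ncols)
        sums[r] += in[(row0 + r) * ncols + c];
  return false;
}
```

Both paths produce the same sums; the difference is that the fast path only applies when the entire slice, across both unswitched rows, is in bounds.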

After disabling it, here's what I got:


ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     40 bytes stack frame, 68 bytes spill stores, 148 bytes spill loads
ptxas info    : Used 72 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.207872 ms, achieved: 810.444 GB/s

Interestingly, the register usage is also reduced significantly. It looks like 88 registers are enough for this kernel.

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 88 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.246784 ms, achieved: 682.656 GB/s

It seems to align well with the no iter-unroll-factor case, which uses 45 registers. With iter_unroll_factor==2, the persistent buffer size is effectively doubled, which accounts for an additional 20 registers. Moreover, since we unroll the iteration domain, the input cache is allocated as:

  Array<__half, 40, 4> T39;

That's another 20 registers. Overall, 45 + 20 + 20 = 85, which is pretty close to the actual usage of 88, so the estimate holds up. However, it also means there's almost no benefit to increasing the per-thread work, as the register usage efficiency is virtually unchanged.
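The register accounting above can be written out as a compile-time check. This only restates the rough estimate from this comment, assuming one 32-bit register per float and two halves per register. std::uint16_t stands in for __half so that it compiles on the host, and the constant names are hypothetical:

```cpp
#include <cstdint>

// Assumed: each 32-bit register holds one float or two 16-bit halves.
constexpr int kBytesPerReg = 4;

// Registers needed to hold `elems` elements of `bytesPerElem` bytes each.
constexpr int regsFor(int elems, int bytesPerElem) {
  return (elems * bytesPerElem + kBytesPerReg - 1) / kBytesPerReg;
}

constexpr int kBaselineRegs = 45;  // measured, no iter unroll
// Doubling the persistent float[20] buffer adds another float[20].
constexpr int kExtraPersistent = regsFor(20, static_cast<int>(sizeof(float)));
// Input cache Array<__half, 40, 4>: 40 halves.
constexpr int kInputCache = regsFor(40, static_cast<int>(sizeof(std::uint16_t)));
constexpr int kEstimate = kBaselineRegs + kExtraPersistent + kInputCache;

static_assert(kExtraPersistent == 20, "float[20] needs 20 registers");
static_assert(kInputCache == 20, "40 halves need 20 registers");
static_assert(kEstimate == 85, "rough estimate vs. 88 reported by ptxas");
```

The 3-register gap to the 88 reported by ptxas is presumably index math and other temporaries that this accounting leaves out.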

The occupancy is 2 blocks per SM with this configuration, whereas it's 4 with no iter unroll. Since the former has 2x more work per thread than the latter, it seems our generated code is not able to exploit instruction-level parallelism well. Here's what ncu shows about the number of ready warps with iter unroll:

Eligible Warps Per Scheduler [warp] 0.71

Here's the number of ready warps with no iter unroll:

Eligible Warps Per Scheduler [warp] 1.29

This is disappointing. I suspect the input reading and the main computation loop are not overlapped well, whereas in the latter config, the warp scheduler should automatically switch between active warps.
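A rough way to see why ILP matters here: both configurations keep the same number of rows in flight per SM, but the iter-unroll one does so with half as many resident warps. The sketch assumes 288-thread blocks and the occupancies quoted above (4 blocks/SM without iter unroll, 2 with it). The struct and function names are hypothetical:

```cpp
// Assumed: 288-thread blocks (9 warps), occupancies as quoted in this thread.
struct Config {
  int blocksPerSm;   // resident blocks per SM
  int rowsPerBlock;  // 2 with iter_unroll_factor == 2, else 1
};

constexpr int kWarpsPerBlock = 288 / 32;

// Warps the schedulers can switch between on one SM.
constexpr int residentWarps(Config c) { return c.blocksPerSm * kWarpsPerBlock; }

// Rows being normalized concurrently on one SM.
constexpr int rowsInFlight(Config c) { return c.blocksPerSm * c.rowsPerBlock; }

constexpr Config kNoUnroll{4, 1};
constexpr Config kIterUnroll{2, 2};
```

Both configurations hold 4 rows in flight per SM, but with 18 instead of 36 resident warps. Each warp in the iter-unroll kernel would therefore need roughly twice as much independent work in flight to hide the same memory latency, and the lower eligible-warp count from ncu suggests the generated code does not provide it.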

I also tried disabling unrolling of the iter domain while still splitting it by a factor of 2 per thread. That reduced register usage, but the performance didn't improve much, as the stall stats were still as bad as in the unroll case:

ptxas info    : 3 bytes gmem
ptxas info    : Compiling entry function '_ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_' for 'sm_80'
ptxas info    : Function properties for _ZN11CudaCodeGen7kernel1ENS_6TensorINS_6__halfELi1ELi1EEENS0_IS1_Li2ELi2EEES2_S2_S2_S3_
ptxas         .     0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 68 registers, 16 bytes smem, 464 bytes cmem[0], 8 bytes cmem[2]

Launch Parameters: BlockDim.x = 288, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = 4096, GridDim.y = -1, GridDim.z = -1, Smem Size = 1152
kernel1 run in 0.234496 ms, achieved: 718.429 GB/s
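The difference between unrolling the split iteration domain and only splitting it can be sketched as two loop structures for one thread's work. This is a hypothetical host-side illustration: a plain sum stands in for the layer-norm reductions, kPerThread = 20 mirrors the per-thread buffer size mentioned above, and the function names are made up. It shows structure only. Whether loads actually overlap on the GPU is up to the CUDA compiler and hardware:

```cpp
#include <cassert>

constexpr int kRows = 2;        // iteration domain split by 2 per thread
constexpr int kPerThread = 20;  // per-thread buffer size (as in float[20])

// Unrolled: both rows are loaded into a cache (like T39) before either is
// reduced, so the loads can be issued back to back, but both rows' buffers
// are live at the same time.
void unrolled(const float* in, float* sums) {
  assert(in != nullptr && sums != nullptr);
  float cache[kRows][kPerThread];
  for (int r = 0; r < kRows; ++r)
    for (int i = 0; i < kPerThread; ++i) cache[r][i] = in[r * kPerThread + i];
  for (int r = 0; r < kRows; ++r) {
    float acc = 0.f;
    for (int i = 0; i < kPerThread; ++i) acc += cache[r][i];
    sums[r] = acc;
  }
}

// Split but not unrolled: rows are processed one after another, so only one
// row's buffer is live at a time, but the second row's loads are not issued
// until the first row's loop body has run.
void serial(const float* in, float* sums) {
  assert(in != nullptr && sums != nullptr);
  for (int r = 0; r < kRows; ++r) {
    float cache[kPerThread];
    for (int i = 0; i < kPerThread; ++i) cache[i] = in[r * kPerThread + i];
    float acc = 0.f;
    for (int i = 0; i < kPerThread; ++i) acc += cache[i];
    sums[r] = acc;
  }
}
```

Both variants compute the same result. The unrolled form trades extra live registers for the chance to overlap the two rows' loads, while the serial form frees registers but serializes the two rows within a thread.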

Overall, in this case, using iter unroll does not improve register efficiency and fails to extract ILP, which explains the gap between 682 GB/s and 830 GB/s. As shown above, limiting the register usage to 72 increases the occupancy to 3 blocks, which brings the performance almost on par with the no-unroll case. However, it is clear that the no-unroll approach makes the most sense.

So, I think we can say this removal makes sense. Another thing we may want to change is removing this unswitch. I recall finding that it didn't work well for outer grid normalization, which is why it has the !is_outer_grid_persistence condition. I'll check whether it should be disabled globally.

liqiangxl commented 1 year ago

Thanks for the further investigation. What does "completely evenly divisible" mean? Does it mean elements = iter_unroll_factor^N?

I think the conclusions and actions we need to follow are: (1) removing iter_unroll_factor in #386 is good; (2) check the other schedulers that use iter_unroll_factor; (3) if a scheduler needs iter_unroll_factor, check its unswitch. Anything else?

naoyam commented 1 year ago

What does "completely evenly divisible" mean? Does it mean elements = iter_unroll_factor^N?

No. The split transformations need to be divisible. Otherwise, the unswitch condition would be false for some threads, and due to the bulk-synchronous (BSP) nature of the kernel, all threads would be held back by the threads taking the slow path. Look at the actual generated code.
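The divisibility condition can be illustrated with a small check over a chain of split factors. This is a hypothetical sketch: allSplitsDivisible is not an nvFuser function, factors are applied innermost-first, and the extents below are illustrative rather than the ones in this benchmark:

```cpp
#include <cassert>
#include <vector>

// Returns true if splitting `extent` by each factor in turn (innermost
// first) never leaves a remainder, i.e. no split needs a bounds predicate.
bool allSplitsDivisible(long extent, const std::vector<int>& factors) {
  assert(extent > 0);
  for (int f : factors) {
    assert(f > 0);
    if (extent % f != 0) return false;
    extent /= f;
  }
  return true;
}
```

With factors {4, 288, 5} (vectorize, threads per block, persistent batch), an extent of 5760 splits evenly while 5120 does not. 1024 = 2^10 does not split evenly by {4, 288} even though it is a power of 2, while 5760 splits evenly without being one, so a power-of-N extent is neither necessary nor sufficient.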

naoyam commented 1 year ago

I think the conclusions and actions we need to follow are: (1) removing iter_unroll_factor in #386 is good; (2) check the other schedulers that use iter_unroll_factor; (3) if a scheduler needs iter_unroll_factor, check its unswitch. Anything else?

Yes. I'll take care of (2) and (3).