NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

InstanceNorm Perf for NHWC layout does not look optimal #443

Open kevinstephano opened 1 year ago

kevinstephano commented 1 year ago

It looks like the NHWC layout, which requires an outer reduction, takes about twice as long as the NCHW layout (225 us vs. 119 us). Two things look suspect about the NHWC kernel. First, it uses 128 registers per thread compared to 36 for the NCHW kernel, which suggests a non-optimal heuristic choice. Second, its blocks have 400 threads (16 x 25), which is 12.5 warps, so the last warp in each block is only half full.
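To make the block-size and register concerns concrete, here is a rough occupancy estimate. It is a sketch that assumes H100 (sm_90) per-SM limits of 65,536 registers, 2,048 resident threads, and 32 resident blocks, with registers allocated per warp in 256-register chunks; it ignores shared memory and other launch limits, so treat its output as an estimate rather than what a profiler would report. The block sizes and register counts come from the kernel traces in this issue (128 threads / 36 registers for NCHW, 16 x 25 threads / 128 registers for NHWC).

```python
import math

# Assumed per-SM limits for an H100 (sm_90); check your device's documented limits.
REGS_PER_SM = 65536
MAX_THREADS_PER_SM = 2048
MAX_BLOCKS_PER_SM = 32
WARP_SIZE = 32
REG_ALLOC_UNIT = 256  # assumed: registers are allocated per warp in chunks of 256


def warps_per_block(threads_per_block: int) -> int:
    # A partially filled warp still occupies a full warp slot.
    return math.ceil(threads_per_block / WARP_SIZE)


def resident_blocks(threads_per_block: int, regs_per_thread: int) -> int:
    warps = warps_per_block(threads_per_block)
    # Round each warp's register allocation up to the allocation unit.
    regs_per_warp = math.ceil(regs_per_thread * WARP_SIZE / REG_ALLOC_UNIT) * REG_ALLOC_UNIT
    by_regs = REGS_PER_SM // (regs_per_warp * warps)
    by_threads = MAX_THREADS_PER_SM // (warps * WARP_SIZE)
    return min(by_regs, by_threads, MAX_BLOCKS_PER_SM)


def occupancy(threads_per_block: int, regs_per_thread: int) -> float:
    # Fraction of the SM's warp slots that are filled by resident blocks.
    resident_warps = warps_per_block(threads_per_block) * resident_blocks(
        threads_per_block, regs_per_thread
    )
    return resident_warps / (MAX_THREADS_PER_SM // WARP_SIZE)


print(400 / WARP_SIZE)       # 12.5 warps per NHWC block
print(occupancy(400, 128))   # NHWC: 16x25 threads, 128 registers
print(occupancy(128, 36))    # NCHW: 128 threads, 36 registers
```

Under these assumptions the NHWC configuration fits only one 13-warp block per SM (about 20% of warp slots), while the NCHW kernel fits 12 blocks (75%).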

NCHW Layout Repro:

import torch
from nvfuser import FusionDefinition

# NCHW input (N=256, C=128, H=28, W=28) plus per-channel affine weight and bias
inputs = [
    torch.randn(256, 128, 28, 28, device='cuda'),
    torch.randn(1, 128, 1, 1, device='cuda'),
    torch.randn(1, 128, 1, 1, device='cuda'),
]

def func(fd: FusionDefinition):
    T0 = fd.from_pytorch(inputs[0])
    weight = fd.from_pytorch(inputs[1])
    bias = fd.from_pytorch(inputs[2])
    S1 = fd.define_scalar(1.0e-5)  # epsilon
    # Instance norm: biased statistics over the spatial axes (H, W), which are innermost in NCHW
    var, mean = fd.ops.var_mean(T0, axes=[2, 3], correction=0, keepdim=True)
    T2 = (T0 - mean) / fd.ops.sqrt(var + S1)
    T4 = T2 * weight + bias
    fd.add_output(T4)

with FusionDefinition() as fd:
    func(fd)

for _ in range(5):
    out = fd.execute(inputs)

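Perf result: 119 us

Start (ns)  Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Device  Ctx  Strm  Name
1766687582  119169  318  32768  1  1  128  1  1  36  0.000  0.001  NVIDIA H100 PCIe (0)  1  7  CudaCodeGen::kernel1(CudaCodeGen::Tensor<float, (int)4, (int)4>, CudaCodeGen::Tensor<float, (int)4,…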

NHWC Layout Repro:

import torch
from nvfuser import FusionDefinition

# NHWC input (N=256, H=28, W=28, C=128) plus per-channel affine weight and bias
inputs = [
    torch.randn(256, 28, 28, 128, device='cuda'),
    torch.randn(1, 1, 1, 128, device='cuda'),
    torch.randn(1, 1, 1, 128, device='cuda'),
]

def func(fd: FusionDefinition):
    T0 = fd.from_pytorch(inputs[0])
    weight = fd.from_pytorch(inputs[1])
    bias = fd.from_pytorch(inputs[2])
    S1 = fd.define_scalar(1.0e-5)  # epsilon
    # Instance norm over the spatial axes (H, W). In NHWC these are not the
    # innermost axis, so this becomes an outer reduction.
    var, mean = fd.ops.var_mean(T0, axes=[1, 2], correction=0, keepdim=True)
    T2 = (T0 - mean) / fd.ops.sqrt(var + S1)
    T4 = T2 * weight + bias
    fd.add_output(T4)

with FusionDefinition() as fd:
    func(fd)

for _ in range(5):
    out = fd.execute(inputs)

Scheduler Params:

===== Reduction Stats ========
total_reduction_numel: 784
total_iteration_numel: 32768
vectorize_factor: 4
n_tensor_inputs: 3
max_input_dtype_size: 4
max_persistent_buffer_size: 3136
max_multi_reduction_factor: 41
block(16, 25, 1)

===== Reduction Parameters ========
Tag: Outer persistent kernel heuristic.

Red On Slow Dim
Persistent Kernel
Batches per block: 8

Iteration Domain: blockIdx.x / threadIdx.x / multiple reductions per block / vectorize / factor 2
Inner Reduction Domain: cross block - threadIdx.y / persistent batch - 8 / unroll / factor 4
Launch Parameters: BlockDim.x = 16, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
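The numbers in this dump can be cross-checked with a little arithmetic. The sketch below is one reading of the parameter dump, not the scheduler's own logic: reduction size = H*W, iteration size = N*C, persistent buffer = one fp32 value per reduction element, and 16 threads in x each covering 2 iteration elements (the "factor 2" on the iteration domain). The variable names are illustrative.

```python
import math

# NHWC problem from the repro: N=256, H=28, W=28, C=128, fp32 inputs
N, H, W, C = 256, 28, 28, 128
DTYPE_BYTES = 4

reduction_numel = H * W          # elements reduced per (n, c) instance
iteration_numel = N * C          # number of independent instances
persistent_buffer_bytes = reduction_numel * DTYPE_BYTES

# Block shape and factors as read from the "Reduction Parameters" dump;
# this interpretation of the schedule strings is an assumption.
bdimx, bdimy = 16, 25
iter_factor = 2                  # "vectorize / factor 2" on the iteration domain
persistent_batch, unroll = 8, 4  # "persistent batch - 8 / unroll / factor 4"

grid_x = math.ceil(iteration_numel / (bdimx * iter_factor))
per_thread_reduction = math.ceil(reduction_numel / bdimy)
covered = bdimy * persistent_batch * unroll  # reduction elements covered per block

print(reduction_numel, iteration_numel, persistent_buffer_bytes)  # 784 32768 3136
print(bdimx * bdimy, grid_x)                                      # 400 1024
print(per_thread_reduction, persistent_batch * unroll, covered)   # 32 32 800
```

The implied grid of 1024 blocks matches the GridDim reported in the NHWC kernel trace. Under this reading each thread keeps about 32 reduction elements live for each of its 2 iteration elements, which may help explain the 128-register count, although actual register usage depends on the generated code.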

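Perf result: 225 us

Start (ns)  Duration (ns)  CorrId  GrdX  GrdY  GrdZ  BlkX  BlkY  BlkZ  Reg/Trd  StcSMem (MB)  DymSMem (MB)  Device  Ctx  Strm  Name
2522963283  225409  314  1024  1  1  16  25  1  128  0.000  0.005  NVIDIA H100 PCIe (0)  1  7  CudaCodeGen::kernel1(CudaCodeGen::Tensor<float, (int)4, (int)4>, CudaCodeGen::Tensor<float, (int)4,…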
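One way to compare the two kernels is effective memory bandwidth. This sketch assumes ideal traffic for a persistent kernel: every fp32 input element is read once and every output element is written once, with the small weight and bias tensors ignored, so actual DRAM traffic may differ. The durations are the 119,169 ns and 225,409 ns from the two traces.

```python
# Effective-bandwidth estimate for the two repros (fp32, 256*128*28*28 elements).
# Assumption: ideal traffic = one read of the input + one write of the output.
NUMEL = 256 * 128 * 28 * 28
BYTES_PER_ELEM = 4


def effective_gb_per_s(duration_ns: int, numel: int = NUMEL) -> float:
    total_bytes = 2 * numel * BYTES_PER_ELEM  # read input + write output
    return total_bytes / duration_ns           # bytes per ns == GB/s


nchw = effective_gb_per_s(119_169)
nhwc = effective_gb_per_s(225_409)
print(f"NCHW: {nchw:.0f} GB/s, NHWC: {nhwc:.0f} GB/s, ratio {nchw / nhwc:.2f}")
```

By this measure the NCHW kernel moves roughly 1.7 TB/s and the NHWC kernel roughly 0.9 TB/s for the same amount of data.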

naoyam commented 1 year ago

I looked at the NHWC repro on A100. It looks like a heuristics issue: we are using the block-parallel outer normalization scheduler, and I don't think it has had any specific tuning in a long time.

This is probably one of the cases Liqiang's benchmarking effort would uncover. I expect we will find multiple performance issues and will need to prioritize them.

kevinstephano commented 1 year ago

Christian's suggestions:

[12:12] Christian Sarofeen: Even just put in a really dirty hack of "if exactly these sizes, for this normalization, on this device, do this".
[12:13] Christian Sarofeen: Going further, we could even skip segmentation and heuristics entirely and only do this for instance norm 3D.
[12:13] Christian Sarofeen: We don't need generic fusion support; we just need an explicit nvFuser instance norm, and we can accumulate some small technical debt here.
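The "dirty hack" above amounts to a lookup keyed on the exact problem. A minimal plain-Python sketch of that idea follows; `LaunchConfig`, `TUNED_CONFIGS`, `pick_config`, the example key, and every value in the table are hypothetical placeholders, not nvFuser APIs or tuned numbers.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class LaunchConfig:
    # Hypothetical hand-picked launch parameters for one exact problem.
    block: Tuple[int, int, int]
    grid: Tuple[int, int, int]


# Key: (op name, exact input shape, dtype, device name). Entries are placeholders.
TUNED_CONFIGS = {
    ("instance_norm_3d", (1, 32, 32, 32, 128), "fp16", "NVIDIA H100 PCIe"):
        LaunchConfig(block=(128, 1, 1), grid=(128, 1, 1)),
}


def pick_config(op: str, shape, dtype: str, device: str) -> Optional[LaunchConfig]:
    """Return a hard-coded config on an exact match; None means use the generic heuristics."""
    return TUNED_CONFIGS.get((op, tuple(shape), dtype, device))
```

Anything that misses the table falls back to the normal path, which keeps the technical debt confined to the listed cases.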

kevinstephano commented 1 year ago

Sizes and types that we care about:

Here are the cuDNN layer files describing the shapes we are using: https://gitlab-master.nvidia.com/dl/JoC/model_profiling/-/tree/main/cuDNN/mlperf/mxnet/image_segmentation. In the wild, only non-spatial cases (1x8x7, 1x8x1) are used. If those are hard to read or parse, the three largest layers are (DxHxWxC):
- 128x128x128x32
- 64x64x64x64
- 32x32x32x128
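For instance norm 3D on a channels-last (NDHWC) tensor, each (n, c) instance reduces over D*H*W elements. The sketch below tabulates the per-instance reduction size and the size of one instance in fp16 and fp32 for the three layers above. The batch size is not given here, so N=1 is an assumption and the iteration count is per sample.

```python
# Three largest layers, given as DxHxWxC; batch size assumed to be 1.
LAYERS = [(128, 128, 128, 32), (64, 64, 64, 64), (32, 32, 32, 128)]
DTYPE_BYTES = {"fp16": 2, "fp32": 4}


def instance_norm_sizes(d, h, w, c, n=1):
    reduction_numel = d * h * w   # elements normalized together per instance
    iteration_numel = n * c       # independent (n, c) instances
    per_instance = {k: reduction_numel * b for k, b in DTYPE_BYTES.items()}
    return reduction_numel, iteration_numel, per_instance


for layer in LAYERS:
    red, it, per_inst = instance_norm_sizes(*layer)
    print(layer, red, it, {k: f"{v / 2**20:.2f} MiB" for k, v in per_inst.items()})
```

Even the smallest layer's instances (32,768 elements, 128 KiB in fp32) are far larger than the 784-element instances in the 2D repro, and the largest (8 MiB in fp32) cannot be held in a single SM's registers or shared memory. This suggests the 3D case may need a schedule that splits each reduction across blocks rather than the block-persistent one shown above.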

The dtypes are the standard set, in priority order: fp16, amp, tf32, fp32.

csarofeen commented 1 year ago

The dtypes are fp16 and fp32; amp and tf32 don't mean anything to us.

csarofeen commented 1 year ago

@jacobhinkle I also think we want to pull the logic you wrote in apex directly into an nvFuser contrib module or something similar. We don't need the flexibility you had; for now we can basically hard-code the operator just for instance norm 3D.
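Whatever form the hard-coded operator takes, a small reference is handy for validating it. This is a minimal pure-Python sketch (not the apex code) of channels-last instance norm with biased variance and a per-channel affine, matching the math in the repros above. Spatial dimensions are assumed to be already flattened into one axis, so the same function covers the 2D and 3D cases; the function name is illustrative.

```python
import math
from typing import List


def instance_norm_channels_last(
    x: List[List[List[float]]],  # x[n][s][c]: batch, flattened spatial, channel
    weight: List[float],
    bias: List[float],
    eps: float = 1e-5,
) -> List[List[List[float]]]:
    n_batch, n_spatial, n_chan = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * n_chan for _ in range(n_spatial)] for _ in range(n_batch)]
    for n in range(n_batch):
        for c in range(n_chan):
            # Two-pass statistics over the spatial axis for this (n, c) instance.
            vals = [x[n][s][c] for s in range(n_spatial)]
            mean = sum(vals) / n_spatial
            var = sum((v - mean) ** 2 for v in vals) / n_spatial  # correction=0
            inv_std = 1.0 / math.sqrt(var + eps)
            for s in range(n_spatial):
                out[n][s][c] = (vals[s] - mean) * inv_std * weight[c] + bias[c]
    return out
```

A two-pass mean/variance is used here for clarity; a fused kernel would more likely compute the statistics in a single pass, but the results should agree within floating-point tolerance.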