NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Performance regression caused by leading meta operations on codegen kernel #2613

Closed by jjsjann123 1 month ago

jjsjann123 commented 1 month ago

Even after alias analysis removes the leading meta operations from the generated kernel, we still observe a performance regression.

There are slight segmentation differences between the two versions, but the big kernel is almost identical.

The repro scripts are below.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, False, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T3 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, False, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T4 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, False, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T5 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, False, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 1, 2, 0])
    T6 = fd.ops.cast(T3, dtype=DataType.Float)
    T7 = fd.ops.neg(T6)
    T8 = fd.ops.cast(T7, dtype=DataType.BFloat16)
    T9 = fd.ops.cat([T8, T2], dim=-1)
    S10 = fd.define_scalar(2, dtype=DataType.Int)
    S11 = fd.define_scalar(16, dtype=DataType.Int)
    S12 = fd.define_scalar(1024, dtype=DataType.Int)
    S13 = fd.define_scalar(32, dtype=DataType.Int)
    V14 = fd.define_vector([S10, S11, S12, S13], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T0, shape=V14, broadcast_dims=[2, 3])
    T16 = fd.ops.cast(T5, dtype=DataType.Float)
    T17 = fd.ops.cast(T15, dtype=DataType.Float)
    T18 = fd.ops.mul(T16, T17)
    S19 = fd.define_scalar(2, dtype=DataType.Int)
    S20 = fd.define_scalar(16, dtype=DataType.Int)
    S21 = fd.define_scalar(1024, dtype=DataType.Int)
    S22 = fd.define_scalar(32, dtype=DataType.Int)
    V23 = fd.define_vector([S19, S20, S21, S22], dtype=DataType.Int)
    T24 = fd.ops.broadcast_in_dim(T1, shape=V23, broadcast_dims=[2, 3])
    T25 = fd.ops.cast(T9, dtype=DataType.Float)
    T26 = fd.ops.cast(T24, dtype=DataType.Float)
    T27 = fd.ops.mul(T25, T26)
    T28 = fd.ops.add(T18, T27)
    T29 = fd.ops.cast(T28, dtype=DataType.BFloat16)
    T30 = fd.ops.cat([T29, T4], dim=-1)
    fd.add_output(T30)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((32768,), dtype=torch.bfloat16, device='cuda:0').as_strided((1024, 32), (32, 1)),
    torch.randn((32768,), dtype=torch.bfloat16, device='cuda:0').as_strided((1024, 32), (32, 1)),
    torch.randn((1179568,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 16, 1024, 16), (589824, 32, 576, 1)),
    torch.randn((1179568,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 16, 1024, 16), (589824, 32, 576, 1)),
    torch.randn((0,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 16, 1024, 0), (589824, 32, 576, 1)),
    torch.randn((1179584,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 16, 1024, 32), (589824, 32, 576, 1)),
]
fd.execute(inputs)

vs

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    S3 = fd.define_scalar(2, dtype=DataType.Int)
    S4 = fd.define_scalar(1024, dtype=DataType.Int)
    S5 = fd.define_scalar(1, dtype=DataType.Int)
    S6 = fd.define_scalar(18, dtype=DataType.Int)
    S7 = fd.define_scalar(32, dtype=DataType.Int)
    V8 = fd.define_vector([S3, S4, S5, S6, S7], dtype=DataType.Int)
    T9 = fd.ops.reshape(T2, new_shape=V8)
    T10 = fd.ops.permute(T9, dims=[0, 2, 3, 1, 4])
    T11 = fd.ops.slice(T10, start_indices=[0, 0, 0, 0, 0], end_indices=[2, 1, 16, 1024, 32], strides=[1, 1, 1, 1, 1])
    S12 = fd.define_scalar(2, dtype=DataType.Int)
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(1024, dtype=DataType.Int)
    S15 = fd.define_scalar(32, dtype=DataType.Int)
    V16 = fd.define_vector([S12, S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.reshape(T11, new_shape=V16)
    T18 = fd.ops.slice(T17, start_indices=[0, 0, 0, 0], end_indices=[2, 16, 1024, 16], strides=[1, 1, 1, 1])
    T19 = fd.ops.slice(T17, start_indices=[0, 0, 0, 16], end_indices=[2, 16, 1024, 32], strides=[1, 1, 1, 1])
    T20 = fd.ops.slice(T17, start_indices=[0, 0, 0, 0], end_indices=[2, 16, 1024, 0], strides=[1, 1, 1, 1])

    T17 = fd.ops.segment_set(T17)
    T18 = fd.ops.segment_set(T18)
    T19 = fd.ops.segment_set(T19)
    T20 = fd.ops.segment_set(T20)

    T21 = fd.ops.cast(T19, dtype=DataType.Float)
    T22 = fd.ops.neg(T21)
    T23 = fd.ops.cast(T22, dtype=DataType.BFloat16)
    T24 = fd.ops.cat([T23, T18], dim=-1)
    S25 = fd.define_scalar(2, dtype=DataType.Int)
    S26 = fd.define_scalar(16, dtype=DataType.Int)
    S27 = fd.define_scalar(1024, dtype=DataType.Int)
    S28 = fd.define_scalar(32, dtype=DataType.Int)
    V29 = fd.define_vector([S25, S26, S27, S28], dtype=DataType.Int)
    T30 = fd.ops.broadcast_in_dim(T0, shape=V29, broadcast_dims=[2, 3])
    T31 = fd.ops.cast(T17, dtype=DataType.Float)
    T32 = fd.ops.cast(T30, dtype=DataType.Float)
    T33 = fd.ops.mul(T31, T32)
    S34 = fd.define_scalar(2, dtype=DataType.Int)
    S35 = fd.define_scalar(16, dtype=DataType.Int)
    S36 = fd.define_scalar(1024, dtype=DataType.Int)
    S37 = fd.define_scalar(32, dtype=DataType.Int)
    V38 = fd.define_vector([S34, S35, S36, S37], dtype=DataType.Int)
    T39 = fd.ops.broadcast_in_dim(T1, shape=V38, broadcast_dims=[2, 3])
    T40 = fd.ops.cast(T24, dtype=DataType.Float)
    T41 = fd.ops.cast(T39, dtype=DataType.Float)
    T42 = fd.ops.mul(T40, T41)
    T43 = fd.ops.add(T33, T42)
    T44 = fd.ops.cast(T43, dtype=DataType.BFloat16)
    T45 = fd.ops.cat([T44, T20], dim=-1)
    fd.add_output(T45)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

inputs = [
    torch.randn((32768,), dtype=torch.bfloat16, device='cuda:0').as_strided((1024, 32), (32, 1)),
    torch.randn((32768,), dtype=torch.bfloat16, device='cuda:0').as_strided((1024, 32), (32, 1)),
    torch.randn((1179648,), dtype=torch.bfloat16, device='cuda:0').as_strided((2, 1024, 576), (589824, 576, 1)),
]
fd.execute(inputs)
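
To make the relationship between the two scripts concrete, here is a minimal sketch that tracks only shape, strides, and storage offset through the leading reshape, permute, and slice calls of the second script. The `View` class and its helper functions are hypothetical illustrations, not nvFuser or PyTorch APIs. The reshapes are modeled as a split of the last dimension plus an inserted or removed size-1 dimension, which is an assumption about how they decompose.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class View:
    """Shape, strides (in elements) and storage offset; no data is touched."""
    shape: tuple
    strides: tuple
    offset: int = 0


def contiguous(shape):
    """Row-major view over a freshly allocated buffer."""
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return View(tuple(shape), tuple(reversed(strides)))


def split_last(v, outer, inner):
    """reshape [..., outer * inner] -> [..., outer, inner]."""
    assert v.shape[-1] == outer * inner
    step = v.strides[-1]
    return View(v.shape[:-1] + (outer, inner),
                v.strides[:-1] + (step * inner, step), v.offset)


def unsqueeze(v, dim):
    """Insert a size-1 dimension; its stride is irrelevant, so use 0."""
    return View(v.shape[:dim] + (1,) + v.shape[dim:],
                v.strides[:dim] + (0,) + v.strides[dim:], v.offset)


def permute(v, dims):
    """Reorder dimensions; shape and strides move together."""
    return View(tuple(v.shape[d] for d in dims),
                tuple(v.strides[d] for d in dims), v.offset)


def slice_dim(v, dim, start, stop):
    """Unit-step slice along one dimension: shrink the extent, shift the offset."""
    shape = v.shape[:dim] + (stop - start,) + v.shape[dim + 1:]
    return View(shape, v.strides, v.offset + start * v.strides[dim])


def squeeze(v, dim):
    """Drop a size-1 dimension."""
    assert v.shape[dim] == 1
    return View(v.shape[:dim] + v.shape[dim + 1:],
                v.strides[:dim] + v.strides[dim + 1:], v.offset)


# Mirror the leading ops of the second script on its (2, 1024, 576) input.
t = contiguous((2, 1024, 576))
t = unsqueeze(split_last(t, 18, 32), 2)      # reshape -> (2, 1024, 1, 18, 32)
t = permute(t, (0, 2, 3, 1, 4))              # -> (2, 1, 18, 1024, 32)
t = slice_dim(t, 2, 0, 16)                   # keep 16 of the 18 entries
t17 = squeeze(t, 1)                          # reshape -> (2, 16, 1024, 32)
t18 = slice_dim(t17, 3, 0, 16)
t19 = slice_dim(t17, 3, 16, 32)
t20 = slice_dim(t17, 3, 0, 0)

for name, v in [("T17", t17), ("T18", t18), ("T19", t19), ("T20", t20)]:
    print(name, v.shape, v.strides, v.offset)
```

Under these assumptions every intermediate keeps the strides (589824, 32, 576, 1), which are the strides passed to `as_strided` for the 4D inputs of the first script. Only the innermost extent and the storage offset differ between T17, T18, T19, and T20.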

The second script is segmented as shown below. If we disregard the leading no-op segment, the computation looks the same as in the first script.

group details:
g{(no_op)
inputs:
T2_g[ iS202{2}, iS203{1024}, iS6{i7} ] __bfloat
outputs:
T10_g[ iS167{2}, iS46{16}, iS168{1024}, iS169{32} ] __bfloat
T11_g[ iS190{2}, iS50{16}, iS191{1024}, iS52{16} ] __bfloat
T12_g[ iS177{2}, iS54{16}, iS178{1024}, iS56{16} ] __bfloat

T33_l[ iS204{2}, iS205{1024}, iS147{18}rf, iS148{( ceilDiv(i7, 18) )}rf ] = view( T2_g[ iS202{2}, iS203{1024}, iS6{i7} ] )
(42)
T34_g[ iS206{2}, iS207{1024}, bS151{1}, iS152{18}, iS153{( ceilDiv(i7, 18) )} ]
   = broadcast( T33_l[ iS204{2}, iS205{1024}, iS147{18}rf, iS148{( ceilDiv(i7, 18) )}rf ] )
(43)
T4_g[ iS154{2}, bS156{1}, iS157{18}, iS155{1024}, iS158{32} ]
   = Set.Permute( T34_g[ iS206{2}, iS207{1024}, bS151{1}, iS152{18}, iS153{( ceilDiv(i7, 18) )} ], cache_op=Streaming )
(44)
T5_l[ iS159{2}rf, bS160{1}rf, iS23{16}rf, iS162{1024}rf, iS163{32}rf ]
   = slice( T4_g[ iS154{2}, bS156{1}, iS157{18}, iS155{1024}, iS158{32} ], { {0, 2, 1} {0, 1, 1} {0, 16, 1} {0, 1024, 1} {0, 32, 1}
 } )
(3)
T6_g[ iS164{2}, iS27{16}, iS165{1024}, iS166{32} ]
   = squeeze( T5_l[ iS159{2}rf, bS160{1}rf, iS23{16}rf, iS162{1024}rf, iS163{32}rf ] )
(4)
T7_g[ iS187{2}rf, iS31{16}rf, iS188{1024}rf, iS34{16}rf ]
   = slice( T6_g[ iS164{2}, iS27{16}, iS165{1024}, iS166{32} ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {0, 16, 1} } )
(6)
T11_g[ iS190{2}, iS50{16}, iS191{1024}, iS52{16} ]
   = SegmenterSet( T7_g[ iS187{2}rf, iS31{16}rf, iS188{1024}rf, iS34{16}rf ] )
(12)
T8_g[ iS174{2}rf, iS36{16}rf, iS175{1024}rf, iS39{16}rf ]
   = slice( T6_g[ iS164{2}, iS27{16}, iS165{1024}, iS166{32} ], { {0, 2, 1} {0, 16, 1} {0, 1024, 1} {16, 32, 1} } )
(8)
T12_g[ iS177{2}, iS54{16}, iS178{1024}, iS56{16} ]
   = SegmenterSet( T8_g[ iS174{2}rf, iS36{16}rf, iS175{1024}rf, iS39{16}rf ] )
(13)
T10_g[ iS167{2}, iS46{16}, iS168{1024}, iS169{32} ]
   = SegmenterSet( T6_g[ iS164{2}, iS27{16}, iS165{1024}, iS166{32} ] )
(11)
i444 = ceilDiv(i7, 18);
(40)
}

g{(pointwise)
inputs:
T12_g[ iS177{2}, iS54{16}, iS178{1024}, iS56{16} ] __bfloat
outputs:
T16_g[ iS183{2}, iS70{16}, iS184{1024}, iS72{16} ] __bfloat

T14_g[ iS179{2}, iS62{16}, iS180{1024}, iS64{16} ]
   = __bfloat2float(T12_g[ iS177{2}, iS54{16}, iS178{1024}, iS56{16} ]);
(15)
T15_g[ iS181{2}, iS66{16}, iS182{1024}, iS68{16} ]
   = -T14_g[ iS179{2}, iS62{16}, iS180{1024}, iS64{16} ];
(16)
T16_g[ iS183{2}, iS70{16}, iS184{1024}, iS72{16} ]
   = __float2bfloat(T15_g[ iS181{2}, iS66{16}, iS182{1024}, iS68{16} ]);
(17)
}

g{(pointwise)
inputs:
T0_g[ iS208{1024}, iS209{32} ] __bfloat
T1_g[ iS212{1024}, iS213{32} ] __bfloat
T10_g[ iS167{2}, iS46{16}, iS168{1024}, iS169{32} ] __bfloat
T11_g[ iS190{2}, iS50{16}, iS191{1024}, iS52{16} ] __bfloat
T16_g[ iS183{2}, iS70{16}, iS184{1024}, iS72{16} ] __bfloat
outputs:
T32_g[ iS201{2}, iS140{16}, iS141{1024}, iS142{32} ] __bfloat

T25_l[ bS107{1}, bS108{1}, iS214{1024}, iS215{32} ]
   = broadcast( T1_g[ iS212{1024}, iS213{32} ] )
(31)
T26_g[ bS111{1 ex 2}, bS112{1 ex 16}, iS113{1024}, iS114{32} ] = expand( T25_l[ bS107{1}, bS108{1}, iS214{1024}, iS215{32} ], {2, 16, 1024, 32} )
(80)
T28_g[ bS119{1 ex 2}, bS120{1 ex 16}, iS121{1024}, iS122{32} ]
   = __bfloat2float(T26_g[ bS111{1 ex 2}, bS112{1 ex 16}, iS113{1024}, iS114{32} ]);
(34)
T22_l[ iS170{2}, iS96{16}, iS171{1024}, iS172{32} ]
   = __bfloat2float(T10_g[ iS167{2}, iS46{16}, iS168{1024}, iS169{32} ]);
(28)
T20_l[ bS87{1}, bS88{1}, iS210{1024}, iS211{32} ]
   = broadcast( T0_g[ iS208{1024}, iS209{32} ] )
(26)
T21_g[ bS91{1 ex 2}, bS92{1 ex 16}, iS93{1024}, iS94{32} ] = expand( T20_l[ bS87{1}, bS88{1}, iS210{1024}, iS211{32} ], {2, 16, 1024, 32} )
(79)
T23_g[ bS99{1 ex 2}, bS100{1 ex 16}, iS101{1024}, iS102{32} ]
   = __bfloat2float(T21_g[ bS91{1 ex 2}, bS92{1 ex 16}, iS93{1024}, iS94{32} ]);
(29)
T24_g[ iS173{2}, iS104{16}, iS105{1024}, iS106{32} ]
   = T22_l[ iS170{2}, iS96{16}, iS171{1024}, iS172{32} ]
   * T23_g[ bS99{1 ex 2}, bS100{1 ex 16}, iS101{1024}, iS102{32} ];
(30)
T17_l[ iS185{2}, iS74{16}, iS186{1024}, iS77{32}rf ]
   = pad( T16_g[ iS183{2}, iS70{16}, iS184{1024}, iS72{16} ], {0, 0, 0, 0, 0, 0, 0, 16} )
(20)
i302 = 0 + 16;
(18)
T18_g[ iS192{2}, iS79{16}, iS193{1024}, iS82{( ( 0 + 16 ) + 16 )}rf ]
   = pad( T11_g[ iS190{2}, iS50{16}, iS191{1024}, iS52{16} ], {0, 0, 0, 0, 0, 0, i302, 0} )
(24)
T19_g[ iS194{2}, iS84{16}, iS195{1024}, iS86{32} ]
   = cat( T17_l[ iS185{2}, iS74{16}, iS186{1024}, iS77{32}rf ], T18_g[ iS192{2}, iS79{16}, iS193{1024}, iS82{( ( 0 + 16 ) + 16 )}rf ], 3 )
(25)
T27_g[ iS196{2}, iS116{16}, iS197{1024}, iS118{32} ]
   = __bfloat2float(T19_g[ iS194{2}, iS84{16}, iS195{1024}, iS86{32} ]);
(33)
T29_l[ iS198{2}, iS124{16}, iS125{1024}, iS126{32} ]
   = T27_g[ iS196{2}, iS116{16}, iS197{1024}, iS118{32} ]
   * T28_g[ bS119{1 ex 2}, bS120{1 ex 16}, iS121{1024}, iS122{32} ];
(35)
T30_g[ iS199{2}, iS128{16}, iS129{1024}, iS130{32} ]
   = T24_g[ iS173{2}, iS104{16}, iS105{1024}, iS106{32} ]
   + T29_l[ iS198{2}, iS124{16}, iS125{1024}, iS126{32} ];
(36)
T31_g[ iS200{2}, iS132{16}, iS133{1024}, iS134{32} ]
   = __float2bfloat(T30_g[ iS199{2}, iS128{16}, iS129{1024}, iS130{32} ]);
(37)
T32_g[ iS201{2}, iS140{16}, iS141{1024}, iS142{32} ]
   = Set( T31_g[ iS200{2}, iS132{16}, iS133{1024}, iS134{32} ], cache_op=Streaming )
(39)
i326 = i302 + 16;
(22)
}

 ** CUDA GPU Kernel Summary (cuda_gpu_kern_sum):

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     69.5            40032          5    8006.4   10880.0      3520     10976       3994.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
     24.4            14048          1   14048.0   14048.0     14048     14048          0.0  <unnamed>::nvfuser_pointwise_f0_c1_r0_g0(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…
      6.1             3520          1    3520.0    3520.0      3520      3520          0.0  <unnamed>::nvfuser_pointwise_f0_c1_r0_g1(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…

 ** CUDA GPU Kernel Summary (cuda_gpu_kern_sum):

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     71.3            53056          1   53056.0   53056.0     53056     53056          0.0  <unnamed>::nvfuser_pointwise_f0_c1_r0_g2(<unnamed>::Tensor<<unnamed>::__bfloat, (int)2, (int)2>, <u…
     24.3            18112          3    6037.3    3552.0      3520     11040       4332.5  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      4.3             3232          1    3232.0    3232.0      3232      3232          0.0  <unnamed>::nvfuser_pointwise_f0_c1_r0_g0(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…

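The kernel times from the two programs are very different: the largest nvFuser pointwise kernel takes 14048 ns in the first profile but 53056 ns in the second, roughly 3.8x slower.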

jjsjann123 commented 1 month ago

This issue blocks https://github.com/Lightning-AI/lightning-thunder/pull/731 and is one of the reasons for the large performance regression in qkv_split_rope.

jjsjann123 commented 1 month ago

Looking at the generated kernel, we are missing vectorized loads on the non-padded tensor inputs. The launch parameters are also slightly different, but that may well be a side effect of the lack of vectorization.
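
As a rough illustration of what makes a strided input eligible for vectorized loads, the sketch below applies a simplified alignment model. It is not nvFuser's vectorization analysis. The function name `max_vector_width` is hypothetical, and the model assumes 2-byte bfloat16 elements, loads of at most 16 bytes, and a base allocation aligned to at least 16 bytes.

```python
def max_vector_width(shape, strides, offset, elem_bytes=2, max_bytes=16):
    """Largest power-of-two element count per load under a simplified model.

    Assumes the underlying allocation itself is aligned to max_bytes.
    Strides and offset are in elements, as in torch.as_strided.
    """
    # The innermost dimension must be unit-stride and non-empty.
    if strides[-1] != 1 or shape[-1] == 0:
        return 1
    width = max_bytes // elem_bytes
    # Every row must start on a vector boundary, so the innermost extent,
    # the storage offset, and all outer strides must be divisible by width.
    quantities = [shape[-1], offset]
    quantities += [st for sz, st in zip(shape[:-1], strides[:-1]) if sz > 1]
    while width > 1 and any(q % width for q in quantities):
        width //= 2
    return width


# Layouts of the sliced inputs from the repro (offsets 0 and 16 elements).
layout = ((2, 16, 1024, 16), (589824, 32, 576, 1))
print(max_vector_width(*layout, offset=0))
print(max_vector_width(*layout, offset=16))
# A permuted layout whose innermost stride is not 1 falls back to scalar loads.
print(max_vector_width((2, 16, 32, 1024), (589824, 32, 1, 576), offset=0))
```

Under this model the sliced inputs of the repro would admit 8-element loads, which suggests the memory layout alone does not rule out vectorization. It says nothing about what the real analysis decides or why.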

jjsjann123 commented 1 month ago

So I think the next step here is to take a quick look at the vectorization analysis.

An orthogonal topic is to support vectorization on PadOp, which will be required in the presegmentation pass where we will be aggressively pushing out PadOp to avoid kernel segmentation.

jjsjann123 commented 1 month ago

How embarrassing. The large regression might be coming from allocation order inference.

https://github.com/NVIDIA/Fuser/pull/2630 seems to resolve the larger rope regression locally for me. I'll rerun the benchmark tomorrow. :crossed_fingers:
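
For readers less familiar with the layout terminology, the repro scripts describe each input's memory layout through `stride_order`. The sketch below shows one way those values relate to the `as_strided` strides, assuming `stride_order[i]` is the rank of dimension `i`'s stride with 0 being the fastest-varying dimension. That reading is consistent with the values in the scripts above, but the helper `stride_order_from_strides` is hypothetical and not part of nvFuser.

```python
def stride_order_from_strides(strides):
    """Rank each dimension by its stride; rank 0 is the fastest-varying one.

    Ties are broken so that the later (inner) dimension counts as faster.
    """
    by_stride = sorted(range(len(strides)), key=lambda d: (strides[d], -d))
    order = [0] * len(strides)
    for rank, dim in enumerate(by_stride):
        order[dim] = rank
    return order


# Strides used by the 4D inputs of the first script.
print(stride_order_from_strides((589824, 32, 576, 1)))
# Strides of the contiguous 2D and 3D inputs.
print(stride_order_from_strides((32, 1)))
print(stride_order_from_strides((589824, 576, 1)))
```

The three results match the `stride_order` arguments passed to `define_tensor` in the repro scripts: `[3, 1, 2, 0]`, `[1, 0]`, and `[2, 1, 0]`.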

jjsjann123 commented 1 month ago

Following up with a benchmark to check #2630's performance impact: https://gist.github.com/jjsjann123/87345938c0dd0c12b83c2b8f4c42fa9c

It looks like it helped with the forward pass at least, but quite a bit of regression remains on the backward pass.

jjsjann123 commented 1 month ago

Looks like we are generating lots of pointwise kernels on backward rope. I suspect that is just alias analysis not pushing things out aggressively enough. Will try my luck with #2608.

jjsjann123 commented 1 month ago

:cry: The backward regression is more than just alias analysis. I'm seeing kernel performance issues as well, even with #2608 and #2630.

The backward fusion pattern looks very similar, except for some slices at the beginning and a permute at the end: https://gist.github.com/jjsjann123/87345938c0dd0c12b83c2b8f4c42fa9c?permalink_comment_id=5127263#gistcomment-5127263 I suspect it's the permute that is giving us vectorization issues again. I'll confirm that.