NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

LayerNorm fusion in Diffusion Transformer has worse performance than expected #2146

Closed parthmannan closed 3 weeks ago

parthmannan commented 7 months ago

nvFuser-generated code for a fusion block in DiT performs worse than expected. The subgraph performs a LayerNorm + Mul + Add + Add computation, as shown in the code below. The nvFuser code was generated using the Lightning Thunder compiler for PyTorch.

Thunder Forward performance: 99us
Thunder Backward performance: 1034us

Torch.Compile Forward performance: 115us
Torch.Compile Backward performance: 457us

The highlighted subgraph is shown below.

[Screenshot (2024-03-20): highlighted subgraph]

Reproduction script (requires Thunder to be installed):

import torch
import torch.nn as nn
import torch.nn.functional as F
import thunder

class Net(nn.Module):
    def __init__(
        self,
        config):
        super(Net, self).__init__()
        self.ln = nn.LayerNorm(config['hidden_units'])

    def modulate(self, x, scale, shift):
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

    def forward(self, x, scale, shift):
        x   = self.ln(x)
        out = self.modulate(x, scale, shift)
        return out

config = {'hidden_units': 1024, 'seq_len': 1024, 'batch_size': 56}

net = Net(config)
net.cuda()
net.to(dtype=torch.bfloat16)

from thunder.executors.sdpaex import sdpa_ex
executors = [sdpa_ex, thunder.nvfuser_executor, thunder.pytorch_executor]

network_fn = thunder.jit(net, executors=executors)
#network_fn = torch.compile(net)

bench_iters = 10
profile_batch = 5

input_shapes    = [(config['batch_size'], config['seq_len'], config['hidden_units']),
                   (config['batch_size'], config['hidden_units']),
                   (config['batch_size'], config['hidden_units'])]

def generate_io_tensor(net, input_shapes):
    input_tensors = []

    for shape in input_shapes:
        tensor = torch.rand(shape, dtype=torch.bfloat16, requires_grad=True, device='cuda')
        input_tensors.append(tensor)

    target_tensor_size = net(*input_tensors).size()
    target_tensor = torch.rand(target_tensor_size, dtype=torch.bfloat16, device='cuda')

    return input_tensors, target_tensor

for idx in range(bench_iters):
    input_tensors, target_tensor = generate_io_tensor(net, input_shapes)

    ## Profiling code BEGIN
    if idx == profile_batch:
        print("BEGIN PROFILING ITERATION")
        torch.cuda.cudart().cudaProfilerStart()
        # Keep a handle so the same emit_nvtx instance is exited below
        nvtx_ctx = torch.autograd.profiler.emit_nvtx(record_shapes=True)
        nvtx_ctx.__enter__()
    ## Profiling code END

    outputs = network_fn(*input_tensors)
    outputs.backward(target_tensor)

    # #UNCOMMENT FOR DEBUGGING THUNDER
    # # Grab the execution traces for Thunder debugging
    # fwd_trace = thunder.last_traces(network_fn)[-1]
    # bwd_trace = thunder.last_backward_traces(network_fn)[-1]

    # print("forward trace")
    # print(fwd_trace)
    # print("forward nvFuser")
    # for k, v in fwd_trace.python_ctx().items():
    #     if 'nvFusion' in k:
    #         print(v.last_used)
    #         print(v.last_used.last_cuda_code())

    # print("backward trace")
    # print(bwd_trace)
    # print("backward nvFuser")
    # for k, v in bwd_trace.python_ctx().items():
    #     if 'nvFusion' in k:
    #         print(v.last_used)
    #         print(v.last_used.last_cuda_code())

    # import sys
    # sys.exit(0)

    ## Profiling code BEGIN
    if idx == profile_batch:
        print("END PROFILING ITERATION")
        torch.cuda.cudart().cudaProfilerStop()
        nvtx_ctx.__exit__(None, None, None)
    ## Profiling code END

Also attached is the log file containing the nvFuser-generated code for the forward and backward passes of this subgraph.

thunder_layernorm_scale_add_trace.log

rdspring1 commented 7 months ago

Given Input Tensors:

Some observations about the backward trace:

  1. The LayerNorm backward operation is handled by the Inner-Outer-Persistent scheduler.
  2. The backward function for x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) is handled by the Reduction scheduler because of the expanded, broadcast dimension at dimension 1.
  3. When the two are combined together, the Inner-Outer-Persistent scheduler fails for the following reason:

Scheduler _inner_outer_persistent_ ***rejected*** because : to use combined reduction, inner reduction tensor should be [I,I,...,R,R] and outer reduction tensor should be [R,R,...,I,I]

  4. After failing on the original fusion, the segmentation runs and generates the four segmented fusions.
  5. There isn't likely anything wrong with the segmentation algorithm. It just isn't guaranteed to find the optimal sub-fusion partition.
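To see why the backward of x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) ends up with reductions, here is a minimal pure-Python sketch (not Thunder's or nvFuser's code; the helper names are hypothetical). Because scale and shift are broadcast along dimension 1, their gradients must sum grad_out over that dimension while keeping dimensions 0 and 2, i.e. an [I, R, I] reduction, whereas grad_x stays pointwise. A finite-difference check confirms one entry of grad_scale.

```python
import random


def modulate(x, scale, shift):
    """y[b][s][h] = x[b][s][h] * (1 + scale[b][h]) + shift[b][h].

    x is a nested list of shape [B][S][H]; scale and shift are [B][H] and are
    broadcast along dimension 1 (like scale.unsqueeze(1) in the repro).
    """
    return [[[x[b][s][h] * (1 + scale[b][h]) + shift[b][h]
              for h in range(len(x[b][s]))]
             for s in range(len(x[b]))]
            for b in range(len(x))]


def modulate_backward(x, scale, grad_out):
    """Return (grad_x, grad_scale, grad_shift) of modulate() for a given grad_out."""
    B, S, H = len(x), len(x[0]), len(x[0][0])
    # grad_x is purely pointwise: no reduction.
    grad_x = [[[grad_out[b][s][h] * (1 + scale[b][h]) for h in range(H)]
               for s in range(S)] for b in range(B)]
    # scale and shift were broadcast over dim 1, so their gradients sum over
    # dim 1 while keeping dims 0 and 2: an [I, R, I] reduction.
    grad_scale = [[sum(grad_out[b][s][h] * x[b][s][h] for s in range(S))
                   for h in range(H)] for b in range(B)]
    grad_shift = [[sum(grad_out[b][s][h] for s in range(S))
                   for h in range(H)] for b in range(B)]
    return grad_x, grad_scale, grad_shift


def weighted_sum(y, g):
    """Scalar loss sum(y * g), so that d(loss)/dy == g."""
    return sum(y[b][s][h] * g[b][s][h]
               for b in range(len(y)) for s in range(len(y[b]))
               for h in range(len(y[b][s])))


def rand_tensor(*shape):
    """Nested list of uniform random values with the given shape."""
    if len(shape) == 1:
        return [random.uniform(-1.0, 1.0) for _ in range(shape[0])]
    return [rand_tensor(*shape[1:]) for _ in range(shape[0])]


# Tiny finite-difference check of grad_scale[1][2].
random.seed(0)
B, S, H = 2, 3, 4
x, g = rand_tensor(B, S, H), rand_tensor(B, S, H)
scale, shift = rand_tensor(B, H), rand_tensor(B, H)
_, grad_scale, grad_shift = modulate_backward(x, scale, g)

eps = 1e-6
bumped = [row[:] for row in scale]
bumped[1][2] += eps
numeric = (weighted_sum(modulate(x, bumped, shift), g)
           - weighted_sum(modulate(x, scale, shift), g)) / eps
print(abs(numeric - grad_scale[1][2]) < 1e-4)  # True
```

In the real graph these sums run over seq_len = 1024 for each of the 56 * 1024 (batch, hidden) pairs, which is why the segmenter hands this part to a reduction scheduler.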
parthmannan commented 7 months ago

Thanks Ryan! Is there a way of re-writing the model code/Thunder trace that would help the segmentation?

wujingyue commented 6 months ago

After failing on the original fusion, the segmentation runs and generates the four segmented fusions. There isn't likely anything wrong with the segmentation algorithm. It just isn't guaranteed to find the optimal sub-fusion partition.

This is quite unfortunate. I suspect something similar to https://github.com/NVIDIA/Fuser/issues/1707 is happening: a combination of a suboptimal segment-merging order and/or canSchedule not being "monotonic".

wujingyue commented 6 months ago

Is there a way of re-writing the model code/Thunder trace that would help the segmentation?

No, but I think a possible workaround is to modify segmentation to first attempt to segment only at segment_set ops, and to write a pre-segmentation pass that adds segment_set at the right place for this particular pattern. Along these lines, I'm also open to exposing segment_set to Thunder, but it feels like a lot of work to expose a knob that users won't know how to use.

liqiangxl commented 6 months ago

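If we make the following two changes, we can get 2 kernels:

  • (1) add an inner_outer_reduction scheduler. I only added a placeholder; it makes the segmenter believe it can schedule a fusion with both an inner reduction and an outer reduction. The segmenter then merges more exprs to create a fusion using the inner_outer_persistent scheduler.
  • (2) remove the hasNonNormalizePostReductionBCast restriction in the inner_outer_persistent scheduler.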

The total kernel time is reduced from 1.02 ms to 0.63 ms.

Before the change, 4 kernels: 0.342 + 0.311 + 0.209 + 0.157 = 1.019 ms ((1) outer reduction, (2) outer reduction, (3) inner reduction, (4) pointwise)

After the change, 2 kernels: 0.315 + 0.317 = 0.632 ms ((1) outer reduction, (2) inner_outer_persistent)
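A quick sanity check of the kernel-time arithmetic above; the numbers are copied from this comment and the dictionary labels are only for readability.

```python
# Kernel times (ms) from the comment above.
before = {"outer reduction #1": 0.342, "outer reduction #2": 0.311,
          "inner reduction": 0.209, "pointwise": 0.157}
after = {"outer reduction": 0.315, "inner_outer_persistent": 0.317}

total_before = sum(before.values())
total_after = sum(after.values())
print(f"before: {total_before:.3f} ms, after: {total_after:.3f} ms, "
      f"speedup: {total_before / total_after:.2f}x")
# before: 1.019 ms, after: 0.632 ms, speedup: 1.61x
```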

Performance is still below the target, but we can improve it further through the following two work items:

jacobhinkle commented 6 months ago

No, but I think a possible workaround can be to modify segmentation to first attempt to segment only at segment_set and to write a pre-seg pass to add segment_set to the right place for this particular pattern.

Interesting. This pass might not work if there are not enough segment_sets to form a full graph cut, but in such a case it could just proceed with scheduling as usual. This pass could replace the current stage where we check whether the complete fusion can be scheduled.
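A minimal sketch of the pass discussed above, under heavy assumptions: the fusion is modeled as a plain DAG of op names, every edge leaving a segment_set op is cut, and `can_schedule` is a hypothetical stand-in for asking the schedulers whether a group is schedulable. None of these names are nvFuser APIs. If the resulting groups are not all schedulable (too few segment_sets to form a usable cut), the sketch returns None, meaning "fall back to the regular segmenter". The toy rule loosely mirrors the rejection message quoted earlier: one group cannot mix a sequence-axis reduction with a batch-and-sequence reduction.

```python
from collections import defaultdict


def split_at_segment_sets(nodes, edges, is_segment_set):
    """Group ops into candidate segments by cutting every edge whose producer
    is a segment_set op; all other edges keep producer and consumer together."""
    parent = {n: n for n in nodes}

    def find(n):
        while parent[n] != n:
            parent[n] = parent[parent[n]]  # path halving
            n = parent[n]
        return n

    for producer, consumer in edges:
        if not is_segment_set(producer):
            parent[find(producer)] = find(consumer)

    groups = defaultdict(list)
    for n in nodes:
        groups[find(n)].append(n)
    return sorted(groups.values())


def try_segment_set_partition(nodes, edges, is_segment_set, can_schedule):
    """Return the groups if every one is schedulable, else None (fall back)."""
    groups = split_at_segment_sets(nodes, edges, is_segment_set)
    if all(can_schedule(group) for group in groups):
        return groups
    return None


# Hypothetical toy fusion: a reduction over the sequence axis feeding the
# LayerNorm-backward part, which reduces over batch and sequence.
nodes = ["mul", "sum_seq", "segment_set", "ln_grad", "sum_batch_seq"]
edges = [("mul", "sum_seq"), ("sum_seq", "segment_set"),
         ("segment_set", "ln_grad"), ("ln_grad", "sum_batch_seq")]


def toy_can_schedule(group):
    # Toy rule: a single group cannot mix the two reduction patterns.
    return not ("sum_seq" in group and "sum_batch_seq" in group)


print(try_segment_set_partition(nodes, edges, lambda n: n == "segment_set",
                                toy_can_schedule))
print(try_segment_set_partition(nodes, edges, lambda n: False,
                                toy_can_schedule))
```

The first call returns the two groups split at segment_set; the second, with no segment_set to cut at, yields one unschedulable group and returns None, which is where the usual segmentation would take over.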

wujingyue commented 6 months ago

Thanks for the investigation, @liqiangxl !

If we make the following two changes, we can get 2 kernels:

  • (1) add an inner_outer_reduction scheduler. I only added a placeholder; it makes the segmenter believe it can schedule a fusion with both an inner reduction and an outer reduction. The segmenter then merges more exprs to create a fusion using the inner_outer_persistent scheduler.
  • (2) remove the hasNonNormalizePostReductionBCast restriction in the inner_outer_persistent scheduler.

I suspect (2) might be sufficient. When I read the segmenter logging,

$ git checkout wjy/layernorm
$ NVFUSER_DUMP=segmenter_logging python nvfuser_repro.py

the following attempt to merge inner_reduction and pointwise into inner_persistent failed due to "unsupported post reduction normalization". I suspect if we remove that restriction, we can fuse inner_reduction and pointwise into inner_persistent, and then merge that and outer_reduction into inner_outer_persistent. Does that make sense to you?

**Segmenter** Considering fusion:
T21_l[ iS279{56}, bS63{1}, iS280{1024} ]
   = double(1)
   + T20_g[ iS277{56}, bS60{1}, iS278{1024} ];
T22_g[ iS281{56}, bS66{1}, iS282{1024} ]
   = __float2bfloat(T21_l[ iS279{56}, bS63{1}, iS280{1024} ]);
T23_g[ iS283{56}, bS69{1}, iS284{1024} ]
   = Set( T22_g[ iS281{56}, bS66{1}, iS282{1024} ], cache_op=Streaming )
T24_l[ iS71{56}, bS72{1 ex 1024}, iS73{1024} ] = expand( T23_g[ iS283{56}, bS69{1}, iS284{1024} ], {56, 1024, 1024} )
T25_g[ iS74{56}, bS75{1 ex 1024}, iS76{1024} ]
   = __bfloat2float(T24_l[ iS71{56}, bS72{1 ex 1024}, iS73{1024} ]);
T34_g[ iS98{56}, iS285{1024}, iS100{1024} ]
   = T25_g[ iS74{56}, bS75{1 ex 1024}, iS76{1024} ]
   * T26_g[ iS239{56}, iS240{1024}, iS241{1024} ];
T45_g[ iS126{56}, iS287{1024}, iS128{1024} ]
   = T16_g[ bS47{1 ex i0}, bS48{1 ex i1}, iS266{1024} ]
   * T34_g[ iS98{56}, iS285{1024}, iS100{1024} ];
T49_g[ iS136{56}, iS137{1024}, iS138{1024} ]
   = T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ]
   * T45_g[ iS126{56}, iS287{1024}, iS128{1024} ];
T54_l[ iS151{56}, iS152{1024}, iS153{1024} ]
   = -T49_g[ iS136{56}, iS137{1024}, iS138{1024} ];
T55_g[ iS154{56}, iS155{1024}, rS156{1024} ]
   = reduction( T54_l[ iS151{56}, iS152{1024}, iS153{1024} ], op = add, initial value = float(0), allreduce = false )
T56_g[ iS157{56}, iS158{1024}, bS159{1} ]
   = broadcast( T55_g[ iS154{56}, iS155{1024}, rS156{1024} ] )
T57_l[ iS160{56}, iS161{1024}, bS162{1} ]
   = Set( T56_g[ iS157{56}, iS158{1024}, bS159{1} ], cache_op=Streaming )
T61_g[ iS172{56}, iS173{1024} ]
   = squeeze( T57_l[ iS160{56}, iS161{1024}, bS162{1} ] )
T63_g[ iS176{56}, iS177{1024}, bS178{1} ]
   = broadcast( T61_g[ iS172{56}, iS173{1024} ] )
T64_l[ iS179{56}, iS180{1024}, bS181{1} ]
   = Set( T63_g[ iS176{56}, iS177{1024}, bS178{1} ], cache_op=Streaming )
T65_g[ iS182{56}, iS183{1024}, bS184{1} ]
   = Set( T64_l[ iS179{56}, iS180{1024}, bS181{1} ], cache_op=Streaming )
T66_l[ iS185{56}, iS186{1024}, bS187{1 ex 1024} ] = expand( T65_g[ iS182{56}, iS183{1024}, bS184{1} ], {56, 1024, 1024} )
T67_g[ iS188{56}, iS189{1024}, bS190{1 ex 1024} ]
   = double(0.00097656200000000005)
   * T66_l[ iS185{56}, iS186{1024}, bS187{1 ex 1024} ];
T50_l[ iS139{56}, iS140{1024}, iS141{1024} ]
   = T12_g[ iS35{56}, iS36{1024}, iS259{1024} ]
   * T45_g[ iS126{56}, iS287{1024}, iS128{1024} ];
T51_g[ iS142{56}, iS143{1024}, rS144{1024} ]
   = reduction( T50_l[ iS139{56}, iS140{1024}, iS141{1024} ], op = add, initial value = float(0), allreduce = false )
T52_g[ iS145{56}, iS146{1024}, bS147{1} ]
   = broadcast( T51_g[ iS142{56}, iS143{1024}, rS144{1024} ] )
T53_l[ iS148{56}, iS149{1024}, bS150{1} ]
   = Set( T52_g[ iS145{56}, iS146{1024}, bS147{1} ], cache_op=Streaming )
T58_g[ iS163{56}, iS164{1024}, bS165{1} ]
   = double(-0.5)
   * T53_l[ iS148{56}, iS149{1024}, bS150{1} ];
T59_g[ iS288{56}, iS289{1024}, bS168{1} ]
   = pow(T5_g[ iS260{56}, iS261{1024}, bS16{1} ]
  , double(3));
T60_g[ iS169{56}, iS170{1024}, bS171{1} ]
   = T58_g[ iS163{56}, iS164{1024}, bS165{1} ]
   * T59_g[ iS288{56}, iS289{1024}, bS168{1} ];
T62_l[ iS174{56}, iS175{1024} ]
   = squeeze( T60_g[ iS169{56}, iS170{1024}, bS171{1} ] )
T68_g[ iS191{56}, iS192{1024}, bS193{1} ]
   = broadcast( T62_l[ iS174{56}, iS175{1024} ] )
T69_g[ iS194{56}, iS195{1024}, bS196{1} ]
   = Set( T68_g[ iS191{56}, iS192{1024}, bS193{1} ], cache_op=Streaming )
T70_l[ iS197{56}, iS198{1024}, bS199{1} ]
   = Set( T69_g[ iS194{56}, iS195{1024}, bS196{1} ], cache_op=Streaming )
T71_g[ iS200{56}, iS201{1024}, bS202{1 ex 1024} ] = expand( T70_l[ iS197{56}, iS198{1024}, bS199{1} ], {56, 1024, 1024} )
T76_g[ iS215{56}, iS216{1024}, bS217{1 ex 1024} ]
   = double(2)
   * T71_g[ iS200{56}, iS201{1024}, bS202{1 ex 1024} ];
T72_l[ iS290{56}, iS291{1024}, bS205{1} ]
   = broadcast( T4_g[ iS255{56}, iS256{1024} ] )
T73_g[ iS206{56}, iS207{1024}, bS208{1} ]
   = Set( T72_l[ iS290{56}, iS291{1024}, bS205{1} ], cache_op=Streaming )
T74_g[ iS209{56}, iS210{1024}, bS211{1} ]
   = Set( T73_g[ iS206{56}, iS207{1024}, bS208{1} ], cache_op=Streaming )
T75_l[ iS212{56}, iS213{1024}, bS214{1 ex 1024} ] = expand( T74_g[ iS209{56}, iS210{1024}, bS211{1} ], {56, 1024, 1024} )
T77_g[ iS218{56}, iS219{1024}, iS292{1024} ]
   = T7_g[ iS252{56}, iS253{1024}, iS254{1024} ]
   - T75_l[ iS212{56}, iS213{1024}, bS214{1 ex 1024} ];
T78_l[ iS221{56}, iS222{1024}, iS293{1024} ]
   = T76_g[ iS215{56}, iS216{1024}, bS217{1 ex 1024} ]
   * T77_g[ iS218{56}, iS219{1024}, iS292{1024} ];
d371 = reciprocal(double(1024));
T79_g[ iS224{56}, iS225{1024}, iS294{1024} ]
   = T78_l[ iS221{56}, iS222{1024}, iS293{1024} ]
   * d371;
T80_g[ iS227{56}, iS228{1024}, iS295{1024} ]
   = T67_g[ iS188{56}, iS189{1024}, bS190{1 ex 1024} ]
   + T79_g[ iS224{56}, iS225{1024}, iS294{1024} ];
T81_g[ iS230{56}, iS231{1024}, iS232{1024} ]
   = T49_g[ iS136{56}, iS137{1024}, iS138{1024} ]
   + T80_g[ iS227{56}, iS228{1024}, iS295{1024} ];
T82_g[ iS233{56}, iS234{1024}, iS235{1024} ]
   = __float2bfloat(T81_g[ iS230{56}, iS231{1024}, iS232{1024} ]);

Scheduler _no_op_ ***rejected*** because : reduction of non-zero elements is not supported
Scheduler _matmul_ ***rejected*** because : Matmul scheduler supports fusions only with a single mma opor supports a mul-sum pair which can be replaced with a mma op
Scheduler _reduction_ ***rejected*** because : need persistent buffers that reduction scheduler doesn't handle
Scheduler _transpose_ ***rejected*** because : no support for reduction ops
Scheduler _pointwise_ ***rejected*** because : no support for reduction ops
Scheduler _inner_persistent_ ***rejected*** because : unsupported post reduction normalization
Scheduler _outer_persistent_ ***rejected*** because : schedule_heuristic doesn't match with reduction type `inner_persistent`.
Scheduler _inner_outer_persistent_ ***rejected*** because : heuristicType() doesn't match with reduction type `inner_persistent`.
wujingyue commented 6 months ago

I'd love to know what hasNonNormalizePostReductionBCast does exactly. Some broadcasts seem really trivial and should be removed before segmentation. For example,

T61_g[ iS172{56}, iS173{1024} ]
   = squeeze( T57_l[ iS160{56}, iS161{1024}, bS162{1} ] )
T63_g[ iS176{56}, iS177{1024}, bS178{1} ]
   = broadcast( T61_g[ iS172{56}, iS173{1024} ] )

and

T62_l[ iS174{56}, iS175{1024} ]
   = squeeze( T60_g[ iS169{56}, iS170{1024}, bS171{1} ] )
T68_g[ iS191{56}, iS192{1024}, bS193{1} ]
   = broadcast( T62_l[ iS174{56}, iS175{1024} ] )

That leaves only three broadcasts:

  1. T56 that immediately follows T55, a sum.
  2. T52 that immediately follows T51, a sum.
  3. T72 that's applied on a segment input.

T72 isn't reachable from any reduction so shouldn't affect hasNonNormalizePostReductionBCast. T56 and T52 look like the very normal pattern of normalization after reduction.

liqiangxl commented 6 months ago

As the function name implies, hasNonNormalizePostReductionBCast looks for a post-reduction resolved broadcast domain. It then checks whether that broadcast follows the normalization pattern by propagating backwards: it tracks the IDs that were resolved and makes sure there is a mapping to a TensorView before a reduction. See issue https://github.com/NVIDIA/Fuser/issues/2046, where we discussed removing these broadcast-squeeze and squeeze-broadcast ops. I'll ping you on Teams.

liqiangxl commented 6 months ago

Yes, relaxing only the hasNonNormalizePostReductionBCast restriction, as @wujingyue suggested, should also work. If we can remove the broadcast-squeeze and squeeze-broadcast ops, we don't even need to change hasNonNormalizePostReductionBCast.

wujingyue commented 6 months ago

@liqiangxl How about first trying to remove squeeze-broadcast before segmentation? I think it's an optimization that's useful anyway given Thunder uses broadcast_in_dims extensively. Then, we can check how things improve and decide next steps.
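A toy sketch of the proposed pre-segmentation cleanup: when broadcast(squeeze(x)) removes and then re-inserts exactly the same axes, the pair is a no-op and consumers of the broadcast can read x directly. The `Op` IR below is hypothetical and far simpler than nvFuser's; it assumes ops are in topological order and that the removed tensors are not fusion outputs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    out: str          # name of the tensor this op produces
    kind: str         # "squeeze", "broadcast", or any other op name
    inputs: tuple     # names of the tensors this op consumes
    axes: tuple = ()  # axes removed (squeeze) or inserted (broadcast)


def remove_squeeze_broadcast(ops):
    """Drop broadcast(squeeze(x)) pairs that re-insert exactly the squeezed axes."""
    producer = {op.out: op for op in ops}
    replace = {}  # output of a no-op broadcast -> the original tensor x
    for op in ops:
        if op.kind != "broadcast":
            continue
        src = producer.get(op.inputs[0])
        if src is not None and src.kind == "squeeze" and src.axes == op.axes:
            replace[op.out] = src.inputs[0]

    rewritten = []
    for op in ops:
        if op.out in replace:
            continue  # the redundant broadcast disappears
        new_inputs = tuple(replace.get(t, t) for t in op.inputs)
        rewritten.append(Op(op.out, op.kind, new_inputs, op.axes))

    # A squeeze whose output is no longer consumed is dead; drop it too.
    used = {t for op in rewritten for t in op.inputs}
    return [op for op in rewritten if op.kind != "squeeze" or op.out in used]


# Simplified version of T56 -> T61 -> T63 from the segmenter log
# (the intermediate Set ops are omitted; "consumer" stands in for later uses).
ops = [
    Op("T55", "sum", ("T54",)),
    Op("T56", "broadcast", ("T55",), (2,)),
    Op("T61", "squeeze", ("T56",), (2,)),
    Op("T63", "broadcast", ("T61",), (2,)),
    Op("consumer", "mul", ("T63",)),
]
result = remove_squeeze_broadcast(ops)
for op in result:
    print(op.out, op.kind, op.inputs)
```

After the pass, the only remaining broadcast is T56 directly after the sum, which is the ordinary normalization-after-reduction shape.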

liqiangxl commented 6 months ago

Sounds good to me! I added a repro file to your branch wjy/layernorm; after manually removing the squeeze-broadcast ops, it generates two kernels.

liqiangxl commented 5 months ago

The fixes are merged into the main branch. Current performance measured on H100: two kernels, 0.265 ms (outer reduction) + 0.329 ms (inner_outer_persistent) = 0.594 ms. Previous performance: 1.0 ms.

Further optimization opportunities: there are 3 inter-segment tensors, which are outputs of segment 1 and inputs of segment 2.

T12_g[ iS35{56}, iS36{1024}, iS259{1024} ] float
T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ] float
T15_g[ iS44{56}, iS45{1024}, iS264{1024} ] float

These tensors can be computed pointwise from other inputs of segment 2. In other words, they can be recomputed in segment 2 instead of being written out by segment 1 and read back by segment 2.
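To put a number on that round trip, here is a rough byte count using the three shapes listed above (fp32 = 4 bytes). It assumes the expanded broadcast axis of T14 (bS43{1 ex 1024}) is not materialized, so T14 stores only 56 * 1024 elements; it counts bytes only and says nothing about achieved bandwidth.

```python
FP32_BYTES = 4

# Element counts of the inter-segment tensors. Assumption: T14's broadcast
# axis is not materialized in global memory.
tensors = {
    "T12": 56 * 1024 * 1024,
    "T14": 56 * 1024,
    "T15": 56 * 1024 * 1024,
}
bytes_per_pass = sum(n * FP32_BYTES for n in tensors.values())
# Each tensor is written by the first segment and read back by the second.
round_trip = 2 * bytes_per_pass
print(f"{bytes_per_pass / 1e6:.1f} MB written, {round_trip / 1e6:.1f} MB total traffic")
# 470.0 MB written, 940.0 MB total traffic
```

Recomputing these values in the second segment would avoid roughly this much extra global-memory traffic, at the cost of re-reading whatever inputs the recomputation needs.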

Hi @parthmannan, can you double-check the fix and see whether further optimization is needed? Thanks!

liqiangxl commented 4 months ago

Another possible optimization is revising getVectorizationFactor. Currently it returns the minimum vectorization factor over vectorizable_inputs_outputs, so in this case the 3 inter-segment fp32 tensors cap the whole kernel at a factor of 4. The scheduler should be able to use a factor of 8 and reduce it to 4 only for the 3 inter-segment fp32 tensors.
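A hypothetical sketch contrasting the two policies. It assumes the widest per-thread vector access is 16 bytes (128-bit loads/stores) and ignores the alignment and divisibility checks a real scheduler also performs; `max_vect_factor` and the tensor mix are illustrative, not nvFuser's getVectorizationFactor.

```python
MAX_ACCESS_BYTES = 16  # assumption: widest per-thread vector access is 128 bits
DTYPE_BYTES = {"bf16": 2, "fp32": 4}


def max_vect_factor(dtype):
    """Largest vectorization factor a tensor of this dtype allows."""
    return MAX_ACCESS_BYTES // DTYPE_BYTES[dtype]


def uniform_factor(dtypes):
    """Policy as described in the comment: one factor, the min over all tensors."""
    return min(max_vect_factor(d) for d in dtypes)


def per_tensor_factors(dtypes, base):
    """Proposed policy: schedule with `base`, clamped per tensor by its dtype."""
    return [min(base, max_vect_factor(d)) for d in dtypes]


# Hypothetical tensor mix: bf16 inputs/outputs plus the 3 inter-segment fp32 tensors.
dtypes = ["bf16", "bf16", "bf16", "fp32", "fp32", "fp32"]
print(uniform_factor(dtypes))         # 4
print(per_tensor_factors(dtypes, 8))  # [8, 8, 8, 4, 4, 4]
```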

liqiangxl commented 4 months ago

@csarofeen torch.compile uses 5 kernels (see the doc about these kernels). Mathematically, the backward has:

F1: Outer reduction [I, R, I] → Iter = 56K, Redu = 1024
F2: Inner persistent [I, I, R] → Iter = 56K, Redu = 1024
F3: Outer reduction [R, R, I] → Iter = 1024, Redu = 56K
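The iteration and reduction extents follow directly from the [56, 1024, 1024] tensor shape (batch_size, seq_len, hidden_units in the repro config); "56K" means 56 * 1024 = 57344. A small illustrative helper, `extents`, computes them from an I/R pattern:

```python
from math import prod

SHAPE = (56, 1024, 1024)  # (batch_size, seq_len, hidden_units) from the repro config


def extents(pattern, shape=SHAPE):
    """(iteration extent, reduction extent) for a pattern such as "IRI"."""
    iteration = prod(n for p, n in zip(pattern, shape) if p == "I")
    reduction = prod(n for p, n in zip(pattern, shape) if p == "R")
    return iteration, reduction


for name, pattern in [("F1", "IRI"), ("F2", "IIR"), ("F3", "RRI")]:
    print(name, pattern, extents(pattern))
# F1 IRI (57344, 1024)
# F2 IIR (57344, 1024)
# F3 RRI (1024, 57344)
```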

Triton used 5 kernels for backward, 133 + 127 + 164 + 27 + 28 = 479 us in total.

F1: Outer reduction: xnumel = 57344, rnumel = 1024
F2: Inner normalization: xnumel = 57344, rnumel = 1024
F3: split into 3 kernels: xnumel = 1024, rnumel = 57344 = 448 * 128
Outer reduction: xnumel = 131072, rnumel = 448 (2 reductions)
Outer reduction: xnumel = 1024, rnumel = 128 (1 reduction)
Outer reduction: xnumel = 1024, rnumel = 128 (1 reduction)
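Triton's handling of F3 is a two-stage (split) reduction. The long reduction axis of 57344 = 448 * 128 rows is cut into 128 chunks, so the first stage exposes 1024 * 128 = 131072 independent partial sums instead of only 1024, and the second stage finishes each of the two reductions in its own small kernel. Below is an illustrative pure-Python sketch of the idea (not the generated Triton code), checked against a direct column sum on a small matrix.

```python
import math
import random


def split_column_sum(rows, n_chunks):
    """Two-stage column sum of an [R][C] matrix (R must divide by n_chunks).

    Stage 1: n_chunks * C independent partial sums, each over R // n_chunks rows.
    Stage 2: C sums, each over the n_chunks partials of one column.
    """
    n_rows, n_cols = len(rows), len(rows[0])
    if n_rows % n_chunks:
        raise ValueError("row count must be divisible by n_chunks")
    chunk = n_rows // n_chunks
    partials = [[sum(rows[k * chunk + r][c] for r in range(chunk))
                 for c in range(n_cols)]
                for k in range(n_chunks)]
    return [sum(partials[k][c] for k in range(n_chunks)) for c in range(n_cols)]


# Sizes from the F3 kernels above (only printed; too large for pure Python).
R, C, N_CHUNKS = 57344, 1024, 128
print("stage 1:", N_CHUNKS * C, "partial sums of", R // N_CHUNKS, "rows each")
print("stage 2:", C, "sums of", N_CHUNKS, "partials each")

# Small-scale check against a direct column sum.
random.seed(1)
small = [[random.random() for _ in range(5)] for _ in range(12)]
direct = [sum(row[c] for row in small) for c in range(5)]
split = split_column_sum(small, n_chunks=4)
print(all(math.isclose(a, b) for a, b in zip(direct, split)))  # True
```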

In nvFuser, F1 is an outer reduction kernel, and F2 and F3 are merged into an inner-outer persistent kernel. The F1 outer reduction is slow compared with torch.compile because of the 3 additional output tensors and because its vectorization factor is limited to 4.

nvFuser used 2 kernels for backward, 268 (outer reduction) + 325 (inner-outer persistent) = 593 us in total.
F1: Outer reduction: xnumel = 57344, rnumel = 1024
F2 & F3: InnerOuter normalization

naoyam commented 4 months ago

Interesting results. Thanks.

I guess the impact on the first kernel is particularly large because it's the first kernel, so there's no L2 caching for the input read. The limitation of the vectorization factor has been known for a while now; it may be time to work on it.

For the second segment, I understand that our combined scheduler may not always have much benefit over the segmented approach, but still it's a little disappointing that it takes about the same time as the four-segment approach. Is our performance within an expected range? Is there any room for (quick) improvement?

liqiangxl commented 4 months ago

I'll do some checks. This case (outer dim = 56K, inner dim = 1024) is not covered in our current benchmarks (the largest outer dim is 16K), so there may be room for further improvement.

kevinstephano commented 3 weeks ago

Closing as this was mostly addressed.