Given Input Tensors:
x = [batch-size, seq-len, hidden-units]
scale = [batch-size, hidden-units]
shift = [batch-size, hidden-units]
Some observations about the backward trace:
x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
is handled by the Reduction scheduler because of the expanded, broadcast dimension at dimension 1.
Scheduler _inner_outer_persistent_ ***rejected*** because : to use combined reduction, inner reduction tensor should be [I,I,...,R,R] and outer reduction tensor should be [R,R,...,I,I]
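For reference, here is a minimal PyTorch sketch of this pattern (shapes taken from the trace; dtypes simplified to fp32, whereas the real trace mixes bf16 and fp32) showing where the broadcast dimension and the backward reductions come from:

```python
import torch

B, S, H = 56, 1024, 1024   # batch-size, seq-len, hidden-units, as seen in the trace
x = torch.randn(B, S, H, requires_grad=True)
scale = torch.randn(B, H, requires_grad=True)
shift = torch.randn(B, H, requires_grad=True)

# unsqueeze(1) introduces the broadcast dimension at dim 1 mentioned above
out = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
out.backward(torch.ones_like(out))

# scale.grad / shift.grad reduce over the broadcast seq-len dimension,
# while x.grad stays full size, so the backward graph mixes full-size
# pointwise ops with reductions over different axes.
assert scale.grad.shape == (B, H) and shift.grad.shape == (B, H)
assert x.grad.shape == (B, S, H)
```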
Thanks Ryan! Is there a way of re-writing the model code/Thunder trace that would help the segmentation?
After the original fusion fails to schedule as a whole, segmentation runs and generates the four segmented fusions. There likely isn't anything wrong with the segmentation algorithm; it just isn't guaranteed to find the optimal sub-fusion partition.
This is quite unfortunate. I suspect something similar to https://github.com/NVIDIA/Fuser/issues/1707 is happening -- a combination of a suboptimal segment merging order and/or canSchedule not being "monotonic".
Is there a way of re-writing the model code/Thunder trace that would help the segmentation?
No, but I think a possible workaround could be to modify segmentation to first attempt to segment only at segment_set, and to write a pre-seg pass that adds segment_set at the right place for this particular pattern. Along this line, I'm also open to exposing segment_set to Thunder, but it feels like a lot of work to expose a knob that the user doesn't know how to use.
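For context, a rough sketch of what that knob looks like from the Python frontend side, assuming fd.ops.segment_set is exposed (exact signatures vary across nvFuser versions, and this toy fusion is a made-up stand-in, not the trace from this issue):

```python
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    x = fd.define_tensor(shape=[-1, -1, -1], dtype=DataType.Float)
    # some reduction-producing sub-graph
    s = fd.ops.sum(x, [1])
    # force a segment boundary here so the segmenter never merges across it
    s = fd.ops.segment_set(s)
    b = fd.ops.broadcast(s, [False, True, False])
    y = fd.ops.add(x, b)
    fd.add_output(y)
```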
If we do the following two changes we can get 2 kernels:
- (1) add an inner_outer_reduction scheduler. I only added a placeholder; it is used to make the segmenter believe it can schedule a fusion with an inner reduction and an outer reduction. Then, it will further merge more exprs to create a fusion using the inner_outer_persistent scheduler.
- (2) remove the restriction of hasNonNormalizePostReductionBCast in the inner_outer_persistent scheduler.
The total kernel time is reduced from 1.02 ms to 0.63 ms.
Before change: 4 kernels: 0.342 + 0.311 + 0.209 + 0.157 = 1.019 ms — (1) outer reduction (2) outer reduction (3) inner reduction (4) pointwise
After change: 2 kernels: 0.315 + 0.317 = 0.632 ms — (1) outer reduction (2) inner_outer_persistent
Still slower than the target, but we can further improve through the following two pieces of work:
No, but I think a possible workaround could be to modify segmentation to first attempt to segment only at segment_set, and to write a pre-seg pass that adds segment_set at the right place for this particular pattern.
Interesting. This pass might not work if there were not enough segment_sets to form a full graph cut, but in that case it could just proceed with scheduling as usual. This pass could just replace the current stage where we check scheduling the complete fusion.
Thanks for the investigation, @liqiangxl !
If we do the following two changes we can get 2 kernels:
- (1) add an inner_outer_reduction scheduler. I only added a placeholder; it is used to make the segmenter believe it can schedule a fusion with an inner reduction and an outer reduction. Then, it will further merge more exprs to create a fusion using the inner_outer_persistent scheduler.
- (2) remove the restriction of hasNonNormalizePostReductionBCast in the inner_outer_persistent scheduler.
I suspect (2) might be sufficient. When I read the segmenter logging,
$ git checkout wjy/layernorm
$ NVFUSER_DUMP=segmenter_logging python nvfuser_repro.py
the following attempt to merge inner_reduction and pointwise into inner_persistent failed due to "unsupported post reduction normalization". I suspect if we remove that restriction, we can fuse inner_reduction and pointwise into inner_persistent, and then merge that and outer_reduction into inner_outer_persistent. Does that make sense to you?
**Segmenter** Considering fusion:
T21_l[ iS279{56}, bS63{1}, iS280{1024} ]
= double(1)
+ T20_g[ iS277{56}, bS60{1}, iS278{1024} ];
T22_g[ iS281{56}, bS66{1}, iS282{1024} ]
= __float2bfloat(T21_l[ iS279{56}, bS63{1}, iS280{1024} ]);
T23_g[ iS283{56}, bS69{1}, iS284{1024} ]
= Set( T22_g[ iS281{56}, bS66{1}, iS282{1024} ], cache_op=Streaming )
T24_l[ iS71{56}, bS72{1 ex 1024}, iS73{1024} ] = expand( T23_g[ iS283{56}, bS69{1}, iS284{1024} ], {56, 1024, 1024} )
T25_g[ iS74{56}, bS75{1 ex 1024}, iS76{1024} ]
= __bfloat2float(T24_l[ iS71{56}, bS72{1 ex 1024}, iS73{1024} ]);
T34_g[ iS98{56}, iS285{1024}, iS100{1024} ]
= T25_g[ iS74{56}, bS75{1 ex 1024}, iS76{1024} ]
* T26_g[ iS239{56}, iS240{1024}, iS241{1024} ];
T45_g[ iS126{56}, iS287{1024}, iS128{1024} ]
= T16_g[ bS47{1 ex i0}, bS48{1 ex i1}, iS266{1024} ]
* T34_g[ iS98{56}, iS285{1024}, iS100{1024} ];
T49_g[ iS136{56}, iS137{1024}, iS138{1024} ]
= T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ]
* T45_g[ iS126{56}, iS287{1024}, iS128{1024} ];
T54_l[ iS151{56}, iS152{1024}, iS153{1024} ]
= -T49_g[ iS136{56}, iS137{1024}, iS138{1024} ];
T55_g[ iS154{56}, iS155{1024}, rS156{1024} ]
= reduction( T54_l[ iS151{56}, iS152{1024}, iS153{1024} ], op = add, initial value = float(0), allreduce = false )
T56_g[ iS157{56}, iS158{1024}, bS159{1} ]
= broadcast( T55_g[ iS154{56}, iS155{1024}, rS156{1024} ] )
T57_l[ iS160{56}, iS161{1024}, bS162{1} ]
= Set( T56_g[ iS157{56}, iS158{1024}, bS159{1} ], cache_op=Streaming )
T61_g[ iS172{56}, iS173{1024} ]
= squeeze( T57_l[ iS160{56}, iS161{1024}, bS162{1} ] )
T63_g[ iS176{56}, iS177{1024}, bS178{1} ]
= broadcast( T61_g[ iS172{56}, iS173{1024} ] )
T64_l[ iS179{56}, iS180{1024}, bS181{1} ]
= Set( T63_g[ iS176{56}, iS177{1024}, bS178{1} ], cache_op=Streaming )
T65_g[ iS182{56}, iS183{1024}, bS184{1} ]
= Set( T64_l[ iS179{56}, iS180{1024}, bS181{1} ], cache_op=Streaming )
T66_l[ iS185{56}, iS186{1024}, bS187{1 ex 1024} ] = expand( T65_g[ iS182{56}, iS183{1024}, bS184{1} ], {56, 1024, 1024} )
T67_g[ iS188{56}, iS189{1024}, bS190{1 ex 1024} ]
= double(0.00097656200000000005)
* T66_l[ iS185{56}, iS186{1024}, bS187{1 ex 1024} ];
T50_l[ iS139{56}, iS140{1024}, iS141{1024} ]
= T12_g[ iS35{56}, iS36{1024}, iS259{1024} ]
* T45_g[ iS126{56}, iS287{1024}, iS128{1024} ];
T51_g[ iS142{56}, iS143{1024}, rS144{1024} ]
= reduction( T50_l[ iS139{56}, iS140{1024}, iS141{1024} ], op = add, initial value = float(0), allreduce = false )
T52_g[ iS145{56}, iS146{1024}, bS147{1} ]
= broadcast( T51_g[ iS142{56}, iS143{1024}, rS144{1024} ] )
T53_l[ iS148{56}, iS149{1024}, bS150{1} ]
= Set( T52_g[ iS145{56}, iS146{1024}, bS147{1} ], cache_op=Streaming )
T58_g[ iS163{56}, iS164{1024}, bS165{1} ]
= double(-0.5)
* T53_l[ iS148{56}, iS149{1024}, bS150{1} ];
T59_g[ iS288{56}, iS289{1024}, bS168{1} ]
= pow(T5_g[ iS260{56}, iS261{1024}, bS16{1} ]
, double(3));
T60_g[ iS169{56}, iS170{1024}, bS171{1} ]
= T58_g[ iS163{56}, iS164{1024}, bS165{1} ]
* T59_g[ iS288{56}, iS289{1024}, bS168{1} ];
T62_l[ iS174{56}, iS175{1024} ]
= squeeze( T60_g[ iS169{56}, iS170{1024}, bS171{1} ] )
T68_g[ iS191{56}, iS192{1024}, bS193{1} ]
= broadcast( T62_l[ iS174{56}, iS175{1024} ] )
T69_g[ iS194{56}, iS195{1024}, bS196{1} ]
= Set( T68_g[ iS191{56}, iS192{1024}, bS193{1} ], cache_op=Streaming )
T70_l[ iS197{56}, iS198{1024}, bS199{1} ]
= Set( T69_g[ iS194{56}, iS195{1024}, bS196{1} ], cache_op=Streaming )
T71_g[ iS200{56}, iS201{1024}, bS202{1 ex 1024} ] = expand( T70_l[ iS197{56}, iS198{1024}, bS199{1} ], {56, 1024, 1024} )
T76_g[ iS215{56}, iS216{1024}, bS217{1 ex 1024} ]
= double(2)
* T71_g[ iS200{56}, iS201{1024}, bS202{1 ex 1024} ];
T72_l[ iS290{56}, iS291{1024}, bS205{1} ]
= broadcast( T4_g[ iS255{56}, iS256{1024} ] )
T73_g[ iS206{56}, iS207{1024}, bS208{1} ]
= Set( T72_l[ iS290{56}, iS291{1024}, bS205{1} ], cache_op=Streaming )
T74_g[ iS209{56}, iS210{1024}, bS211{1} ]
= Set( T73_g[ iS206{56}, iS207{1024}, bS208{1} ], cache_op=Streaming )
T75_l[ iS212{56}, iS213{1024}, bS214{1 ex 1024} ] = expand( T74_g[ iS209{56}, iS210{1024}, bS211{1} ], {56, 1024, 1024} )
T77_g[ iS218{56}, iS219{1024}, iS292{1024} ]
= T7_g[ iS252{56}, iS253{1024}, iS254{1024} ]
- T75_l[ iS212{56}, iS213{1024}, bS214{1 ex 1024} ];
T78_l[ iS221{56}, iS222{1024}, iS293{1024} ]
= T76_g[ iS215{56}, iS216{1024}, bS217{1 ex 1024} ]
* T77_g[ iS218{56}, iS219{1024}, iS292{1024} ];
d371 = reciprocal(double(1024));
T79_g[ iS224{56}, iS225{1024}, iS294{1024} ]
= T78_l[ iS221{56}, iS222{1024}, iS293{1024} ]
* d371;
T80_g[ iS227{56}, iS228{1024}, iS295{1024} ]
= T67_g[ iS188{56}, iS189{1024}, bS190{1 ex 1024} ]
+ T79_g[ iS224{56}, iS225{1024}, iS294{1024} ];
T81_g[ iS230{56}, iS231{1024}, iS232{1024} ]
= T49_g[ iS136{56}, iS137{1024}, iS138{1024} ]
+ T80_g[ iS227{56}, iS228{1024}, iS295{1024} ];
T82_g[ iS233{56}, iS234{1024}, iS235{1024} ]
= __float2bfloat(T81_g[ iS230{56}, iS231{1024}, iS232{1024} ]);
Scheduler _no_op_ ***rejected*** because : reduction of non-zero elements is not supported
Scheduler _matmul_ ***rejected*** because : Matmul scheduler supports fusions only with a single mma opor supports a mul-sum pair which can be replaced with a mma op
Scheduler _reduction_ ***rejected*** because : need persistent buffers that reduction scheduler doesn't handle
Scheduler _transpose_ ***rejected*** because : no support for reduction ops
Scheduler _pointwise_ ***rejected*** because : no support for reduction ops
Scheduler _inner_persistent_ ***rejected*** because : unsupported post reduction normalization
Scheduler _outer_persistent_ ***rejected*** because : schedule_heuristic doesn't match with reduction type `inner_persistent`.
Scheduler _inner_outer_persistent_ ***rejected*** because : heuristicType() doesn't match with reduction type `inner_persistent`.
I'd love to know what hasNonNormalizePostReductionBCast does exactly. Some broadcasts seem really trivial and should be gotten rid of before segmentation. For example,
T61_g[ iS172{56}, iS173{1024} ]
= squeeze( T57_l[ iS160{56}, iS161{1024}, bS162{1} ] )
T63_g[ iS176{56}, iS177{1024}, bS178{1} ]
= broadcast( T61_g[ iS172{56}, iS173{1024} ] )
and
T62_l[ iS174{56}, iS175{1024} ]
= squeeze( T60_g[ iS169{56}, iS170{1024}, bS171{1} ] )
T68_g[ iS191{56}, iS192{1024}, bS193{1} ]
= broadcast( T62_l[ iS174{56}, iS175{1024} ] )
That leaves only three broadcasts: T56, T52, and T72.
T72 isn't reachable from any reduction, so it shouldn't affect hasNonNormalizePostReductionBCast. T56 and T52 look like the very normal pattern of normalization after reduction.
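To make the "trivial" part concrete, here is a tiny PyTorch analogue of the squeeze-then-broadcast round trip above; cancelling such a pair is a pure rewrite with no effect on the values:

```python
import torch

t = torch.randn(56, 1024, 1)             # like T57 / T60: [56, 1024, b1]
roundtrip = t.squeeze(-1).unsqueeze(-1)   # squeeze the size-1 dim, then broadcast it back
assert torch.equal(roundtrip, t)
```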
As the function name implies, hasNonNormalizePostReductionBCast looks for a post-reduction resolved broadcast domain. Then it checks whether it follows the normalization pattern by propagating backwards: track the IDs that were resolved and make sure there's a mapping to a TensorView before a reduction. See issue https://github.com/NVIDIA/Fuser/issues/2046.
We had a discussion about removing these broadcast-squeeze and squeeze-broadcast ops. Will ping you in Teams.
Thanks for the investigation, @liqiangxl !
If we do the following two changes we can get 2 kernels:
- (1) add an inner_outer_reduction scheduler. I only added a placeholder; it is used to make the segmenter believe it can schedule a fusion with an inner reduction and an outer reduction. Then, it will further merge more exprs to create a fusion using the inner_outer_persistent scheduler.
- (2) remove the restriction of hasNonNormalizePostReductionBCast in the inner_outer_persistent scheduler.
I suspect (2) might be sufficient. When I read the segmenter logging,
$ git checkout wjy/layernorm
$ NVFUSER_DUMP=segmenter_logging python nvfuser_repro.py
the following attempt to merge inner_reduction and pointwise into inner_persistent failed due to "unsupported post reduction normalization". I suspect if we remove that restriction, we can fuse inner_reduction and pointwise into inner_persistent, and then merge that and outer_reduction into inner_outer_persistent. Does that make sense to you?
Yes, that should also work. If we can remove broadcast-squeeze and squeeze-broadcast ops, we don't even need to change hasNonNormalizePostReductionBCast.
@liqiangxl How about first trying to remove squeeze-broadcast before segmentation? I think it's an optimization that's useful anyway given Thunder uses broadcast_in_dims extensively. Then, we can check how things improve and decide next steps.
@liqiangxl How about first trying to remove squeeze-broadcast before segmentation? I think it's an optimization that's useful anyway given Thunder uses broadcast_in_dims extensively. Then, we can check how things improve and decide next steps.
Sounds good to me!
I added a repro file to your branch wjy/layernorm; after manually removing squeeze-broadcast, it can generate two kernels.
Fixes are merged in the main branch. Current performance measured on H100: two kernels: 0.265 ms (outer reduction) + 0.329 ms (inner_outer_persistent) = 0.594 ms. Previous performance: 1.0 ms.
Further optimization opportunities: there are 3 inter-segment tensors; they are outputs of segment-1 and inputs of segment-2.
T12_g[ iS35{56}, iS36{1024}, iS259{1024} ] float
T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ] float
T15_g[ iS44{56}, iS45{1024}, iS264{1024} ] float
These tensors can be computed pointwise from other inputs of segment-2. In other words, they can be re-calculated in segment-2 instead of being written out by segment-1 and read back in segment-2.
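A back-of-the-envelope estimate of what that recomputation could save, assuming T12 and T15 are materialized at the full [56, 1024, 1024] size in fp32 while the broadcast dimension of T14 is not materialized:

```python
elems = 56 * 1024 * 1024            # elements in one full-size inter-segment tensor
bytes_per_tensor = elems * 4        # fp32
mib = bytes_per_tensor / 2**20      # ~224 MiB per tensor
# Each tensor is written once by segment-1 and read once by segment-2,
# so recomputing the two full-size tensors avoids roughly:
avoided_traffic_mib = 2 * 2 * mib   # ~896 MiB of DRAM traffic
print(round(mib), round(avoided_traffic_mib))
```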
Hi @parthmannan can you double check the fix and see if further optimization is needed? Thanks!
Another possible optimization is revising getVectorizationFactor; currently it returns min(vect factor of vectorizable_inputs_outputs). Due to the 3 inter-segment fp32 tensors, the max vectorization factor is 4 for this case. The scheduler should be able to use vect = 8 and reduce to 4 only for the 3 inter-segment fp32 tensors.
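A small sketch of that reasoning, assuming 16-byte (128-bit) vectorized accesses:

```python
def max_vect_factor(dtype_bytes, vector_bytes=16):
    # widest vectorized access divided by element size
    return vector_bytes // dtype_bytes

bf16 = max_vect_factor(2)   # 8
fp32 = max_vect_factor(4)   # 4
# Today the min over all vectorizable inputs/outputs is used,
# so the fp32 inter-segment tensors drag everything down to 4.
current = min(bf16, fp32)   # 4
# Proposal: keep vect = 8 for the bf16 tensors and use 4 only for the fp32 ones.
```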
@csarofeen torch.compile is using 5 kernels (see doc about these kernels). Mathematically, backward has:
F1: Outer reduction [I, R, I] → Iter = 56K, Redu = 1024
F2: Inner persistent[I, I, R] → Iter = 56K, Redu = 1024
F3: Outer reduction [R, R, I] → Iter = 1024, Redu = 56K
Triton used 5 kernels for backward, 133 + 127 + 164 + 27 + 28 = 479 us in total.
F1: Outer reduction: xnumel = 57344, rnumel = 1024
F2: Inner normalization: xnumel = 57344, rnumel = 1024
F3: split into 3 kernels: xnumel = 1024, rnumel = 57344 = 448 * 128
Outer reduction: xnumel = 131072, rnumel = 448 (2 reductions)
Outer reduction: xnumel = 1024, rnumel = 128 (1 reduction)
Outer reduction: xnumel = 1024, rnumel = 128 (1 reduction)
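As a sanity check, the arithmetic behind Triton's 3-way split of F3 matches the kernel shapes listed above:

```python
# F3 reduces 57344 elements per output; Triton splits it as 448 * 128.
assert 56 * 1024 == 57344
assert 448 * 128 == 57344
# First split kernel: 1024 outputs * 128 partials each = 131072 rows, each reducing 448 elements.
assert 1024 * 128 == 131072
# The remaining two kernels then reduce the 128 partials per output (rnumel = 128).
```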
In nvFuser, F1 is an outer reduction kernel, and F2 & F3 are merged into an innerOuter persistent kernel. The outer reduction of F1 is slow compared with torch.compile due to the 3 additional output tensors and the vectorization factor being limited to 4.
nvFuser used 2 kernels for backward, 268 (outer redu) + 325 (inner outer) = 593 us in total.
F1: Outer reduction: xnumel = 57344, rnumel = 1024
F2 & F3: InnerOuter normalization
Interesting results. Thanks.
I guess the impact on the first kernel is particularly large as it's the first kernel, so there's no L2 caching for the input read. The limitation of the vectorization factor has been known for a while now; it may be time to work on it.
For the second segment, I understand that our combined scheduler may not always have much benefit over the segmented approach, but still it's a little disappointing that it takes about the same time as the four-segment approach. Is our performance within an expected range? Is there any room for (quick) improvement?
For the second segment, I understand that our combined scheduler may not always have much benefit over the segmented approach, but still it's a little disappointing that it takes about the same time as the four-segment approach. Is our performance within an expected range? Is there any room for (quick) improvement?
I'll do some checks. This case (outer dim = 56K, inner dim = 1024) is not covered in our current benchmarks (the largest outer dim is 16K), so maybe we can make some further improvements.
Closing as this was mostly addressed.
nvFuser-generated code for a fusion block present in DiT has worse than expected performance. The subgraph is performing a LayerNorm + Mul + Add + Add computation as shown in the code below. The nvFuser code was generated using the Lightning Thunder compiler for PyTorch. Below is the subgraph highlighted.
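(The original snippet isn't reproduced here; the following is a hypothetical sketch of that subgraph, with made-up names and residual wiring inferred from the LayerNorm + Mul + Add + Add description and the shapes listed earlier in the thread.)

```python
import torch
import torch.nn.functional as F

def layernorm_scale_add(x, scale, shift, residual):
    # x, residual: [batch, seq, hidden]; scale, shift: [batch, hidden]
    h = F.layer_norm(x, (x.shape[-1],))
    h = h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # Mul + Add
    return h + residual                                    # final Add
```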
Reproducible script (requires Thunder installed)
Also attached is the log file which has the nvFuser generated code for forward and backward of this subgraph.
thunder_layernorm_scale_add_trace.log