NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

reduce inter-segment I/O using rematerialization #2473

Open liqiangxl opened 1 week ago

liqiangxl commented 1 week ago

Motivation: It seems we need an I/O-aware segmenter that reduces the number of tensors passed between segments. A real example comes from #2146, where the fusion is segmented into 2 kernels with 3 inter-segment tensors:

T12_g[ iS35{56}, iS36{1024}, iS259{1024} ] float
T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ] float
T15_g[ iS44{56}, iS45{1024}, iS264{1024} ] float

These tensors can be calculated pointwise from other inputs in segmentation-2. In other words, they can be re-calculated in segmentation-2 instead of being written out in segmentation-1 and read back in segmentation-2.

Potential fix: instead of greedily merging as many exprs as possible, the segmenter could also check the influence on I/O bytes. In other words, the objective could change from minimizing the number of segments to minimizing total I/O bytes.
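To put a rough number on the I/O at stake, here is a minimal sketch that counts the global-memory traffic of the three inter-segment tensors listed above. It assumes each tensor is written once by the first segment and read once by the second, and that an expanded broadcast axis (`1 ex 1024`) occupies a single element in memory; both are assumptions for illustration, not measurements. `DTYPE_BYTES` and `tensor_bytes` are hypothetical helper names.

```python
from math import prod

# Bytes per element for the dtypes that appear in the dump.
DTYPE_BYTES = {"float": 4, "__bfloat": 2}

def tensor_bytes(extents, dtype):
    """Bytes occupied by one tensor in global memory.

    Assumption: an expanded broadcast axis is passed in as extent 1,
    i.e. it is not physically replicated.
    """
    return prod(extents) * DTYPE_BYTES[dtype]

# The three inter-segment tensors from the example.
inter_segment = {
    "T12": ((56, 1024, 1024), "float"),
    "T14": ((56, 1024, 1), "float"),  # bS43{1 ex 1024} counted as extent 1
    "T15": ((56, 1024, 1024), "float"),
}

# Assumption: one write by segment 1 plus one read by segment 2.
traffic = {name: 2 * tensor_bytes(*spec) for name, spec in inter_segment.items()}
total = sum(traffic.values())

for name, nbytes in traffic.items():
    print(f"{name}: {nbytes / 2**20:.1f} MiB")
print(f"total: {total / 2**20:.1f} MiB")
```

Under these assumptions the two full-size float tensors dominate, so an I/O-aware objective would mostly be deciding whether T12 and T15 cross the segment boundary.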

Example from #2146

Segmented_Fusion Dump: -- fusion segments:
Segmented_Fusion{ 
groups: 
g{0, 1, 2, 3, 5, 6, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 123, 124}

g{0, 9, 13, 14, 15, 16, 18, 19, 27, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 58, 60, 63, 65, 66, 67, 69, 70, 71, 72, 73, 74, 75, 76, 125, 126, 127, 128, 129, 130}

edges: 
e{ g{0, 1, 2, 3, 5, 6, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 123, 124}
 -> g{0, 9, 13, 14, 15, 16, 18, 19, 27, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 58, 60, 63, 65, 66, 67, 69, 70, 71, 72, 73, 74, 75, 76, 125, 126, 127, 128, 129, 130}
(T12_g[ iS35{56}, iS36{1024}, iS259{1024} ]) }

e{ g{0, 1, 2, 3, 5, 6, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 123, 124}
 -> g{0, 9, 13, 14, 15, 16, 18, 19, 27, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 58, 60, 63, 65, 66, 67, 69, 70, 71, 72, 73, 74, 75, 76, 125, 126, 127, 128, 129, 130}
(T15_g[ iS44{56}, iS45{1024}, iS264{1024} ]) }

e{ g{0, 1, 2, 3, 5, 6, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35, 123, 124}
 -> g{0, 9, 13, 14, 15, 16, 18, 19, 27, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 58, 60, 63, 65, 66, 67, 69, 70, 71, 72, 73, 74, 75, 76, 125, 126, 127, 128, 129, 130}
(T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ]) }

group details:
g{(reduction)
inputs: 
T0_g[ bS0{1 ex i0}, bS1{1 ex i1}, iS265{1024} ] __bfloat
T1_g[ bS3{1 ex i5}, bS4{1 ex i6}, iS268{1024} ] __bfloat
T3_g[ iS236{56}, iS237{1024}, iS238{1024} ] __bfloat
T4_g[ iS255{56}, iS256{1024} ] float
T5_g[ iS260{56}, iS261{1024}, bS16{1} ] float
T6_g[ iS249{56}, iS250{1024}, iS251{1024} ] __bfloat
outputs: 
T12_g[ iS35{56}, iS36{1024}, iS259{1024} ] float
T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ] float
T15_g[ iS44{56}, iS45{1024}, iS264{1024} ] float
T33_g[ iS96{56}, iS97{1024} ] __bfloat
T42_g[ iS120{56}, iS121{1024} ] __bfloat

T7_g[ iS252{56}, iS253{1024}, iS254{1024} ]
   = __bfloat2float(T6_g[ iS249{56}, iS250{1024}, iS251{1024} ]);
(0)
T26_g[ iS239{56}, iS240{1024}, iS241{1024} ]
   = __bfloat2float(T3_g[ iS236{56}, iS237{1024}, iS238{1024} ]);
(19)
T18_g[ bS53{1 ex i5}, bS54{1 ex i6}, iS269{1024} ]
   = __bfloat2float(T1_g[ bS3{1 ex i5}, bS4{1 ex i6}, iS268{1024} ]);
(11)
T16_g[ bS47{1 ex i0}, bS48{1 ex i1}, iS266{1024} ]
   = __bfloat2float(T0_g[ bS0{1 ex i0}, bS1{1 ex i1}, iS265{1024} ]);
(9)
T8_l[ iS257{56}, iS258{1024}, bS25{1} ]
   = broadcast( T4_g[ iS255{56}, iS256{1024} ] )
(1)
T9_g[ iS26{56}, iS27{1024}, bS28{1} ]
   = Set( T8_l[ iS257{56}, iS258{1024}, bS25{1} ], cache_op=Streaming )
(2)
T10_g[ iS29{56}, iS30{1024}, bS31{1} ]
   = Set( T9_g[ iS26{56}, iS27{1024}, bS28{1} ], cache_op=Streaming )
(3)
T11_l[ iS32{56}, iS33{1024}, bS34{1 ex 1024} ] = expand( T10_g[ iS29{56}, iS30{1024}, bS31{1} ], {56, 1024, 1024} )
(123)
T12_g[ iS35{56}, iS36{1024}, iS259{1024} ]
   = T7_g[ iS252{56}, iS253{1024}, iS254{1024} ]
   - T11_l[ iS32{56}, iS33{1024}, bS34{1 ex 1024} ];
(5)
T27_g[ iS242{56}, rS243{1024}, iS244{1024} ]
   = reduction( T26_g[ iS239{56}, iS240{1024}, iS241{1024} ], op = add, initial value = float(0), allreduce = false )
(20)
T28_g[ iS245{56}, iS246{1024} ]
   = __float2bfloat(T27_g[ iS242{56}, rS243{1024}, iS244{1024} ]);
(21)
T29_g[ iS247{56}, bS86{1}, iS248{1024} ]
   = broadcast( T28_g[ iS245{56}, iS246{1024} ] )
(22)
T30_l[ iS88{56}, bS89{1}, iS90{1024} ]
   = Set( T29_g[ iS247{56}, bS86{1}, iS248{1024} ], cache_op=Streaming )
(23)
T31_g[ iS91{56}, bS92{1}, iS93{1024} ]
   = __bfloat2float(T30_l[ iS88{56}, bS89{1}, iS90{1024} ]);
(24)
T32_g[ iS94{56}, iS95{1024} ]
   = squeeze( T31_g[ iS91{56}, bS92{1}, iS93{1024} ] )
(25)
T33_g[ iS96{56}, iS97{1024} ]
   = __float2bfloat(T32_g[ iS94{56}, iS95{1024} ]);
(26)
T13_l[ iS262{56}, iS263{1024}, bS40{1} ]
   = Set( T5_g[ iS260{56}, iS261{1024}, bS16{1} ], cache_op=Streaming )
(6)
T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ] = expand( T13_l[ iS262{56}, iS263{1024}, bS40{1} ], {56, 1024, 1024} )
(124)
T15_g[ iS44{56}, iS45{1024}, iS264{1024} ]
   = T12_g[ iS35{56}, iS36{1024}, iS259{1024} ]
   * T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ];
(8)
T17_g[ iS50{56}, iS51{1024}, iS267{1024} ]
   = T15_g[ iS44{56}, iS45{1024}, iS264{1024} ]
   * T16_g[ bS47{1 ex i0}, bS48{1 ex i1}, iS266{1024} ];
(10)
T19_g[ iS56{56}, iS57{1024}, iS270{1024} ]
   = T17_g[ iS50{56}, iS51{1024}, iS267{1024} ]
   + T18_g[ bS53{1 ex i5}, bS54{1 ex i6}, iS269{1024} ];
(12)
T35_l[ iS101{56}, iS102{1024}, iS271{1024} ]
   = T19_g[ iS56{56}, iS57{1024}, iS270{1024} ]
   * T26_g[ iS239{56}, iS240{1024}, iS241{1024} ];
(28)
T36_g[ iS104{56}, rS105{1024}, iS272{1024} ]
   = reduction( T35_l[ iS101{56}, iS102{1024}, iS271{1024} ], op = add, initial value = float(0), allreduce = false )
(29)
T37_g[ iS107{56}, iS273{1024} ]
   = __float2bfloat(T36_g[ iS104{56}, rS105{1024}, iS272{1024} ]);
(30)
T38_l[ iS109{56}, bS110{1}, iS274{1024} ]
   = broadcast( T37_g[ iS107{56}, iS273{1024} ] )
(31)
T39_g[ iS112{56}, bS113{1}, iS114{1024} ]
   = Set( T38_l[ iS109{56}, bS110{1}, iS274{1024} ], cache_op=Streaming )
(32)
T40_g[ iS115{56}, bS116{1}, iS117{1024} ]
   = __bfloat2float(T39_g[ iS112{56}, bS113{1}, iS114{1024} ]);
(33)
T41_g[ iS118{56}, iS119{1024} ]
   = squeeze( T40_g[ iS115{56}, bS116{1}, iS117{1024} ] )
(34)
T42_g[ iS120{56}, iS121{1024} ]
   = __float2bfloat(T41_g[ iS118{56}, iS119{1024} ]);
(35)
}

g{(inner_outer_persistent)
inputs: 
T0_g[ bS0{1 ex i0}, bS1{1 ex i1}, iS265{1024} ] __bfloat
T2_g[ iS275{56}, bS7{1}, iS276{1024} ] __bfloat
T3_g[ iS236{56}, iS237{1024}, iS238{1024} ] __bfloat
T4_g[ iS255{56}, iS256{1024} ] float
T5_g[ iS260{56}, iS261{1024}, bS16{1} ] float
T6_g[ iS249{56}, iS250{1024}, iS251{1024} ] __bfloat
T12_g[ iS35{56}, iS36{1024}, iS259{1024} ] float
T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ] float
T15_g[ iS44{56}, iS45{1024}, iS264{1024} ] float
outputs: 
T44_g[ iS125{1024} ] __bfloat
T48_g[ iS135{1024} ] __bfloat
T82_g[ iS233{56}, iS234{1024}, iS235{1024} ] __bfloat

T7_g[ iS252{56}, iS253{1024}, iS254{1024} ]
   = __bfloat2float(T6_g[ iS249{56}, iS250{1024}, iS251{1024} ]);
(0)
T26_g[ iS239{56}, iS240{1024}, iS241{1024} ]
   = __bfloat2float(T3_g[ iS236{56}, iS237{1024}, iS238{1024} ]);
(19)
T20_g[ iS277{56}, bS60{1}, iS278{1024} ]
   = __bfloat2float(T2_g[ iS275{56}, bS7{1}, iS276{1024} ]);
(13)
T16_g[ bS47{1 ex i0}, bS48{1 ex i1}, iS266{1024} ]
   = __bfloat2float(T0_g[ bS0{1 ex i0}, bS1{1 ex i1}, iS265{1024} ]);
(9)
T21_l[ iS279{56}, bS63{1}, iS280{1024} ]
   = double(1)
   + T20_g[ iS277{56}, bS60{1}, iS278{1024} ];
(14)
T22_g[ iS281{56}, bS66{1}, iS282{1024} ]
   = __float2bfloat(T21_l[ iS279{56}, bS63{1}, iS280{1024} ]);
(15)
T23_g[ iS283{56}, bS69{1}, iS284{1024} ]
   = Set( T22_g[ iS281{56}, bS66{1}, iS282{1024} ], cache_op=Streaming )
(16)
T24_l[ iS71{56}, bS72{1 ex 1024}, iS73{1024} ] = expand( T23_g[ iS283{56}, bS69{1}, iS284{1024} ], {56, 1024, 1024} )
(125)
T25_g[ iS74{56}, bS75{1 ex 1024}, iS76{1024} ]
   = __bfloat2float(T24_l[ iS71{56}, bS72{1 ex 1024}, iS73{1024} ]);
(18)
T72_l[ iS290{56}, iS291{1024}, bS205{1} ]
   = broadcast( T4_g[ iS255{56}, iS256{1024} ] )
(65)
T73_g[ iS206{56}, iS207{1024}, bS208{1} ]
   = Set( T72_l[ iS290{56}, iS291{1024}, bS205{1} ], cache_op=Streaming )
(66)
T74_g[ iS209{56}, iS210{1024}, bS211{1} ]
   = Set( T73_g[ iS206{56}, iS207{1024}, bS208{1} ], cache_op=Streaming )
(67)
T75_l[ iS212{56}, iS213{1024}, bS214{1 ex 1024} ] = expand( T74_g[ iS209{56}, iS210{1024}, bS211{1} ], {56, 1024, 1024} )
(128)
T77_g[ iS218{56}, iS219{1024}, iS292{1024} ]
   = T7_g[ iS252{56}, iS253{1024}, iS254{1024} ]
   - T75_l[ iS212{56}, iS213{1024}, bS214{1 ex 1024} ];
(70)
T59_g[ iS288{56}, iS289{1024}, bS168{1} ]
   = pow(T5_g[ iS260{56}, iS261{1024}, bS16{1} ]
  , double(3));
(52)
T34_g[ iS98{56}, iS285{1024}, iS100{1024} ]
   = T25_g[ iS74{56}, bS75{1 ex 1024}, iS76{1024} ]
   * T26_g[ iS239{56}, iS240{1024}, iS241{1024} ];
(27)
T46_l[ iS129{56}, iS130{1024}, iS131{1024} ]
   = T15_g[ iS44{56}, iS45{1024}, iS264{1024} ]
   * T34_g[ iS98{56}, iS285{1024}, iS100{1024} ];
(39)
T47_g[ rS132{56}, rS133{1024}, iS134{1024} ]
   = reduction( T46_l[ iS129{56}, iS130{1024}, iS131{1024} ], op = add, initial value = float(0), allreduce = false )
(40)
T48_g[ iS135{1024} ]
   = __float2bfloat(T47_g[ rS132{56}, rS133{1024}, iS134{1024} ]);
(41)
T43_g[ rS122{56}, rS286{1024}, iS124{1024} ]
   = reduction( T34_g[ iS98{56}, iS285{1024}, iS100{1024} ], op = add, initial value = float(0), allreduce = false )
(36)
T44_g[ iS125{1024} ]
   = __float2bfloat(T43_g[ rS122{56}, rS286{1024}, iS124{1024} ]);
(37)
T45_g[ iS126{56}, iS287{1024}, iS128{1024} ]
   = T16_g[ bS47{1 ex i0}, bS48{1 ex i1}, iS266{1024} ]
   * T34_g[ iS98{56}, iS285{1024}, iS100{1024} ];
(38)
T50_l[ iS139{56}, iS140{1024}, iS141{1024} ]
   = T12_g[ iS35{56}, iS36{1024}, iS259{1024} ]
   * T45_g[ iS126{56}, iS287{1024}, iS128{1024} ];
(43)
T51_g[ iS142{56}, iS143{1024}, rS144{1024} ]
   = reduction( T50_l[ iS139{56}, iS140{1024}, iS141{1024} ], op = add, initial value = float(0), allreduce = false )
(44)
T49_g[ iS136{56}, iS137{1024}, iS138{1024} ]
   = T14_g[ iS41{56}, iS42{1024}, bS43{1 ex 1024} ]
   * T45_g[ iS126{56}, iS287{1024}, iS128{1024} ];
(42)
T54_l[ iS151{56}, iS152{1024}, iS153{1024} ]
   = -T49_g[ iS136{56}, iS137{1024}, iS138{1024} ];
(47)
T55_g[ iS154{56}, iS155{1024}, rS156{1024} ]
   = reduction( T54_l[ iS151{56}, iS152{1024}, iS153{1024} ], op = add, initial value = float(0), allreduce = false )
(48)
T56_g[ iS157{56}, iS158{1024}, bS159{1} ]
   = broadcast( T55_g[ iS154{56}, iS155{1024}, rS156{1024} ] )
(49)
T52_g[ iS145{56}, iS146{1024}, bS147{1} ]
   = broadcast( T51_g[ iS142{56}, iS143{1024}, rS144{1024} ] )
(45)
T53_l[ iS148{56}, iS149{1024}, bS150{1} ]
   = Set( T52_g[ iS145{56}, iS146{1024}, bS147{1} ], cache_op=Streaming )
(46)
T58_g[ iS163{56}, iS164{1024}, bS165{1} ]
   = double(-0.5)
   * T53_l[ iS148{56}, iS149{1024}, bS150{1} ];
(51)
T60_g[ iS169{56}, iS170{1024}, bS171{1} ]
   = T58_g[ iS163{56}, iS164{1024}, bS165{1} ]
   * T59_g[ iS288{56}, iS289{1024}, bS168{1} ];
(53)
T69_l[ iS194{56}, iS195{1024}, bS196{1} ]
   = Set( T60_g[ iS169{56}, iS170{1024}, bS171{1} ], cache_op=Streaming )
(129)
T70_g[ iS197{56}, iS198{1024}, bS199{1} ]
   = Set( T69_l[ iS194{56}, iS195{1024}, bS196{1} ], cache_op=Streaming )
(63)
T57_l[ iS160{56}, iS161{1024}, bS162{1} ]
   = Set( T56_g[ iS157{56}, iS158{1024}, bS159{1} ], cache_op=Streaming )
(50)
T64_g[ iS179{56}, iS180{1024}, bS181{1} ]
   = Set( T57_l[ iS160{56}, iS161{1024}, bS162{1} ], cache_op=Streaming )
(130)
T65_g[ iS182{56}, iS183{1024}, bS184{1} ]
   = Set( T64_g[ iS179{56}, iS180{1024}, bS181{1} ], cache_op=Streaming )
(58)
T66_l[ iS185{56}, iS186{1024}, bS187{1 ex 1024} ] = expand( T65_g[ iS182{56}, iS183{1024}, bS184{1} ], {56, 1024, 1024} )
(126)
T67_g[ iS188{56}, iS189{1024}, bS190{1 ex 1024} ]
   = double(0.00097656200000000005)
   * T66_l[ iS185{56}, iS186{1024}, bS187{1 ex 1024} ];
(60)
T71_g[ iS200{56}, iS201{1024}, bS202{1 ex 1024} ] = expand( T70_g[ iS197{56}, iS198{1024}, bS199{1} ], {56, 1024, 1024} )
(127)
T76_l[ iS215{56}, iS216{1024}, bS217{1 ex 1024} ]
   = double(2)
   * T71_g[ iS200{56}, iS201{1024}, bS202{1 ex 1024} ];
(69)
T78_g[ iS221{56}, iS222{1024}, iS293{1024} ]
   = T76_l[ iS215{56}, iS216{1024}, bS217{1 ex 1024} ]
   * T77_g[ iS218{56}, iS219{1024}, iS292{1024} ];
(71)
d371 = reciprocal(double(1024));
(72)
T79_g[ iS224{56}, iS225{1024}, iS294{1024} ]
   = T78_g[ iS221{56}, iS222{1024}, iS293{1024} ]
   * d371;
(73)
T80_l[ iS227{56}, iS228{1024}, iS295{1024} ]
   = T67_g[ iS188{56}, iS189{1024}, bS190{1 ex 1024} ]
   + T79_g[ iS224{56}, iS225{1024}, iS294{1024} ];
(74)
T81_g[ iS230{56}, iS231{1024}, iS232{1024} ]
   = T49_g[ iS136{56}, iS137{1024}, iS138{1024} ]
   + T80_l[ iS227{56}, iS228{1024}, iS295{1024} ];
(75)
T82_g[ iS233{56}, iS234{1024}, iS235{1024} ]
   = __float2bfloat(T81_g[ iS230{56}, iS231{1024}, iS232{1024} ]);
(76)
}

} //Segmented_Fusion

Reproduce: NVFUSER_DUMP=segmented_fusion python v0_2146.py 2>&1 |tee 1.log

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, 1, -1], contiguity=[True, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, 1], contiguity=[True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.ops.cast(T6, dtype=DataType.Float)
    S8 = fd.define_scalar(56, dtype=DataType.Int)
    S9 = fd.define_scalar(1024, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T4, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(56, dtype=DataType.Int)
    S14 = fd.define_scalar(1024, dtype=DataType.Int)
    S15 = fd.define_scalar(1024, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T12, shape=V16, broadcast_dims=[0, 1, 2])
    T18 = fd.ops.sub(T7, T17)
    S19 = fd.define_scalar(56, dtype=DataType.Int)
    S20 = fd.define_scalar(1024, dtype=DataType.Int)
    S21 = fd.define_scalar(1024, dtype=DataType.Int)
    V22 = fd.define_vector([S19, S20, S21], dtype=DataType.Int)
    T23 = fd.ops.broadcast_in_dim(T5, shape=V22, broadcast_dims=[0, 1, 2])
    T24 = fd.ops.mul(T18, T23)
    T25 = fd.ops.cast(T0, dtype=DataType.Float)
    T26 = fd.ops.mul(T24, T25)
    T27 = fd.ops.cast(T1, dtype=DataType.Float)
    T28 = fd.ops.add(T26, T27)
    T29 = fd.ops.cast(T2, dtype=DataType.Float)
    S30 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T31 = fd.ops.add(S30, T29)
    T32 = fd.ops.cast(T31, dtype=DataType.BFloat16)
    S33 = fd.define_scalar(56, dtype=DataType.Int)
    S34 = fd.define_scalar(1024, dtype=DataType.Int)
    S35 = fd.define_scalar(1024, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T32, shape=V36, broadcast_dims=[0, 1, 2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.cast(T3, dtype=DataType.Float)
    T40 = fd.ops.sum(T39, dims=[1], keepdim=False, dtype=DataType.Null)
    T41 = fd.ops.cast(T40, dtype=DataType.BFloat16)
    S42 = fd.define_scalar(56, dtype=DataType.Int)
    S43 = fd.define_scalar(1, dtype=DataType.Int)
    S44 = fd.define_scalar(1024, dtype=DataType.Int)
    V45 = fd.define_vector([S42, S43, S44], dtype=DataType.Int)
    T46 = fd.ops.broadcast_in_dim(T41, shape=V45, broadcast_dims=[0, 2])
    T47 = fd.ops.cast(T46, dtype=DataType.Float)
    T48 = fd.ops.sum(T47, dims=[1], keepdim=False, dtype=DataType.Null)
    T49 = fd.ops.cast(T48, dtype=DataType.BFloat16)
    T50 = fd.ops.mul(T38, T39)
    T51 = fd.ops.mul(T28, T39)
    T52 = fd.ops.sum(T51, dims=[1], keepdim=False, dtype=DataType.Null)
    T53 = fd.ops.cast(T52, dtype=DataType.BFloat16)
    S54 = fd.define_scalar(56, dtype=DataType.Int)
    S55 = fd.define_scalar(1, dtype=DataType.Int)
    S56 = fd.define_scalar(1024, dtype=DataType.Int)
    V57 = fd.define_vector([S54, S55, S56], dtype=DataType.Int)
    T58 = fd.ops.broadcast_in_dim(T53, shape=V57, broadcast_dims=[0, 2])
    T59 = fd.ops.cast(T58, dtype=DataType.Float)
    T60 = fd.ops.sum(T59, dims=[1], keepdim=False, dtype=DataType.Null)
    T61 = fd.ops.cast(T60, dtype=DataType.BFloat16)
    T62 = fd.ops.sum(T50, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T63 = fd.ops.cast(T62, dtype=DataType.BFloat16)
    T64 = fd.ops.mul(T25, T50)
    T65 = fd.ops.mul(T24, T50)
    T66 = fd.ops.sum(T65, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T67 = fd.ops.cast(T66, dtype=DataType.BFloat16)
    T68 = fd.ops.mul(T23, T64)
    T69 = fd.ops.mul(T18, T64)
    T70 = fd.ops.sum(T69, dims=[2], keepdim=False, dtype=DataType.Null)
    S71 = fd.define_scalar(56, dtype=DataType.Int)
    S72 = fd.define_scalar(1024, dtype=DataType.Int)
    S73 = fd.define_scalar(1, dtype=DataType.Int)
    V74 = fd.define_vector([S71, S72, S73], dtype=DataType.Int)
    T75 = fd.ops.broadcast_in_dim(T70, shape=V74, broadcast_dims=[0, 1])
    T76 = fd.ops.neg(T68)
    T77 = fd.ops.sum(T76, dims=[2], keepdim=False, dtype=DataType.Null)
    S78 = fd.define_scalar(56, dtype=DataType.Int)
    S79 = fd.define_scalar(1024, dtype=DataType.Int)
    S80 = fd.define_scalar(1, dtype=DataType.Int)
    V81 = fd.define_vector([S78, S79, S80], dtype=DataType.Int)
    T82 = fd.ops.broadcast_in_dim(T77, shape=V81, broadcast_dims=[0, 1])
    S83 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T84 = fd.ops.mul(S83, T75)
    S85 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T86 = fd.ops.pow(T5, S85)
    T87 = fd.ops.mul(T84, T86)
    T88 = fd.ops.sum(T82, dims=[2], keepdim=False, dtype=DataType.Null)
    T89 = fd.ops.sum(T87, dims=[2], keepdim=False, dtype=DataType.Null)
    S90 = fd.define_scalar(56, dtype=DataType.Int)
    S91 = fd.define_scalar(1024, dtype=DataType.Int)
    S92 = fd.define_scalar(1, dtype=DataType.Int)
    V93 = fd.define_vector([S90, S91, S92], dtype=DataType.Int)
    T94 = fd.ops.broadcast_in_dim(T88, shape=V93, broadcast_dims=[0, 1])
    S95 = fd.define_scalar(56, dtype=DataType.Int)
    S96 = fd.define_scalar(1024, dtype=DataType.Int)
    S97 = fd.define_scalar(1024, dtype=DataType.Int)
    V98 = fd.define_vector([S95, S96, S97], dtype=DataType.Int)
    T99 = fd.ops.broadcast_in_dim(T94, shape=V98, broadcast_dims=[0, 1, 2])
    S100 = fd.define_scalar(0.000976562, dtype=DataType.Double)
    T101 = fd.ops.mul(S100, T99)
    S102 = fd.define_scalar(56, dtype=DataType.Int)
    S103 = fd.define_scalar(1024, dtype=DataType.Int)
    S104 = fd.define_scalar(1, dtype=DataType.Int)
    V105 = fd.define_vector([S102, S103, S104], dtype=DataType.Int)
    T106 = fd.ops.broadcast_in_dim(T89, shape=V105, broadcast_dims=[0, 1])
    S107 = fd.define_scalar(56, dtype=DataType.Int)
    S108 = fd.define_scalar(1024, dtype=DataType.Int)
    S109 = fd.define_scalar(1024, dtype=DataType.Int)
    V110 = fd.define_vector([S107, S108, S109], dtype=DataType.Int)
    T111 = fd.ops.broadcast_in_dim(T106, shape=V110, broadcast_dims=[0, 1, 2])
    S112 = fd.define_scalar(56, dtype=DataType.Int)
    S113 = fd.define_scalar(1024, dtype=DataType.Int)
    S114 = fd.define_scalar(1, dtype=DataType.Int)
    V115 = fd.define_vector([S112, S113, S114], dtype=DataType.Int)
    T116 = fd.ops.broadcast_in_dim(T4, shape=V115, broadcast_dims=[0, 1])
    S117 = fd.define_scalar(56, dtype=DataType.Int)
    S118 = fd.define_scalar(1024, dtype=DataType.Int)
    S119 = fd.define_scalar(1024, dtype=DataType.Int)
    V120 = fd.define_vector([S117, S118, S119], dtype=DataType.Int)
    T121 = fd.ops.broadcast_in_dim(T116, shape=V120, broadcast_dims=[0, 1, 2])
    S122 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T123 = fd.ops.mul(S122, T111)
    T124 = fd.ops.sub(T7, T121)
    T125 = fd.ops.mul(T123, T124)
    S126 = fd.define_scalar(1024.00, dtype=DataType.Double)
    S127 = fd.ops.reciprocal(S126)
    T128 = fd.ops.mul(T125, S127)
    T129 = fd.ops.add(T101, T128)
    T130 = fd.ops.add(T68, T129)
    T131 = fd.ops.cast(T130, dtype=DataType.BFloat16)
    fd.add_output(T49)
    fd.add_output(T61)
    fd.add_output(T63)
    fd.add_output(T67)
    fd.add_output(T131)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

inputs = [
    torch.randn((1024,), dtype=torch.bfloat16, device='cuda:0').as_strided((56, 1024, 1024), (0, 0, 1)),
    torch.randn((1024,), dtype=torch.bfloat16, device='cuda:0').as_strided((56, 1024, 1024), (0, 0, 1)),
    torch.randn((57344,), dtype=torch.bfloat16, device='cuda:0').as_strided((56, 1, 1024), (1024, 1024, 1)),
    torch.randn((58720256,), dtype=torch.bfloat16, device='cuda:0').as_strided((56, 1024, 1024), (1048576, 1024, 1)),
    torch.randn((57344,), dtype=torch.float32, device='cuda:0').as_strided((56, 1024), (1024, 1)),
    torch.randn((57344,), dtype=torch.float32, device='cuda:0').as_strided((56, 1024, 1), (1024, 1, 1)),
    torch.randn((58720256,), dtype=torch.bfloat16, device='cuda:0').as_strided((56, 1024, 1024), (1048576, 1024, 1)),
]
fd.execute(inputs)

wujingyue commented 1 week ago

they can be re-calculated in segmentation-2 instead of being written out in segmentation-1 and read back in segmentation-2.

When you say "re-calculate", are you proposing to compute the three tensors for segment-1 as well as segment-2? If yes, it would be something like rematerialization that trades more compute for less I/O. Otherwise, it indeed sounds like a where-to-put-boundary-in-the-DAG kind of problem that's in the domain of segmentation.

liqiangxl commented 1 week ago

It belongs to rematerialization.

wujingyue commented 1 week ago

It belongs to rematerialization.

I see -- that's a harder problem.

My gut feeling says we would want a segmentation-aware rematerialization pass. One possible implementation is to run segmentation first and then rematerialize only TensorViews across segments (because we know data copy via global mem is expensive). This reminds me of the existing rematerialization pass in Thunder. Any lessons we can learn from there? @IvanYashchuk and @jjsjann123
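A minimal sketch of the "segment first, then rematerialize only cross-segment tensors" idea, on a toy graph simplified from the dump above. The dictionary-based graph, the `POINTWISE` set, and `rematerializable` are hypothetical names for illustration and are not nvFuser APIs. Broadcast, expand and Set chains are collapsed into single entries, and any op outside `POINTWISE` (such as a reduction) is assumed too expensive to recompute.

```python
# Ops assumed cheap enough to recompute in the consumer segment.
POINTWISE = {"cast", "sub", "mul", "expand"}

# tensor -> (op, inputs), simplified from the first segment in the dump.
defs = {
    "T7": ("cast", ["T6"]),
    "T11": ("expand", ["T4"]),    # broadcast/Set/expand chain collapsed
    "T12": ("sub", ["T7", "T11"]),
    "T14": ("expand", ["T5"]),    # Set/expand chain collapsed
    "T15": ("mul", ["T12", "T14"]),
    "T27": ("reduction", ["T26"]),
}

def rematerializable(tensor, available, defs):
    """True if `tensor` can be rebuilt from `available` using only POINTWISE ops."""
    if tensor in available:
        return True
    if tensor not in defs:
        return False
    op, inputs = defs[tensor]
    if op not in POINTWISE:
        return False
    return all(rematerializable(i, available, defs) for i in inputs)

# Fusion inputs that the second segment already reads.
consumer_inputs = {"T0", "T2", "T3", "T4", "T5", "T6"}
edges = ["T12", "T14", "T15"]

to_recompute = [t for t in edges if rematerializable(t, consumer_inputs, defs)]
print(to_recompute)
```

In this toy model all three edge tensors qualify, because they depend only on T4, T5 and T6, which the second segment already takes as inputs. A real pass would also have to check that the recomputation does not change what the scheduler can handle.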

csarofeen commented 1 week ago

Is this a case where there isn’t a better min-cut for the segmentation so we need to rematerialize? Otherwise is it that rematerializing a tensor seems to be the most straightforward implementation? If it’s the former I would imagine having a post segmentation pass that tries to move tensors from one group to another to find a better min-cut (as long as it doesn’t change the heuristic choice) may be an effective strategy to start looking at memory cost aware segmentation.

csarofeen commented 1 week ago

Rematerialization of course could be another option (I’m not doubting that), I’m just wondering what that algorithm would look like. This is something that could be particularly valuable if we have a full forward-backward graph.

jjsjann123 commented 1 week ago

This reminds me of the existing rematerialization pass in Thunder. Any lessons we can learn from there?

I'll let @IvanYashchuk cover that question :)

As @csarofeen pointed out with "find a better min-cut (as long as it doesn’t change the heuristic choice)", nvFuser segments need to comply with what the scheduler can handle. That is another constraint on top of the min-cut.

liqiangxl commented 6 days ago

Is this a case where there isn’t a better min-cut for the segmentation so we need to rematerialize? Otherwise is it that rematerializing a tensor seems to be the most straightforward implementation? If it’s the former I would imagine having a post segmentation pass that tries to move tensors from one group to another to find a better min-cut (as long as it doesn’t change the heuristic choice) may be an effective strategy to start looking at memory cost aware segmentation.

In this case, there doesn't seem to be a better min-cut that can be reached by moving tensors from one group to another. All of these inter-segment tensors (12, 14, and 15) are needed to calculate the reductions in each segment. A simplified version of the fusion is as follows: (image)

IvanYashchuk commented 6 days ago

https://github.com/NVIDIA/Fuser/issues/2473#issuecomment-2197631460:

This reminds me of the existing rematerialization pass in Thunder. Any lessons we can learn from there?

Rematerialization in Thunder does the min-cut on the producer-consumer graph with the restriction that no node is allowed to be moved from consumer to producer. Currently, Thunder doesn't use any shape information on the tensors because I initially thought symbolic shapes would be more important in Thunder. There are only two preferences encoded as weights:

  1. Node weight is scaled by the tensor dtype byte size resulting in a preference for lower precision tensors to be saved in global memory,
  2. Producer inputs are already in global memory, so the corresponding node weights are adjusted, resulting in a preference for these inputs to be put in the cut set.
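The two preferences above can be illustrated with a toy node-weighted min-cut. This is a sketch under stated assumptions, not Thunder's implementation: each tensor is split into an `in` and an `out` node joined by an edge whose capacity is the tensor's weight, dependency edges get infinite capacity, and a plain Edmonds-Karp max-flow finds the cut. The function name `min_cut_nodes` and the discount applied to producer inputs are hypothetical.

```python
from collections import defaultdict, deque

INF = float("inf")

def min_cut_nodes(node_weight, deps, sources, sinks):
    """Return the tensors to save so that `sinks` can be rebuilt from them."""
    cap = defaultdict(lambda: defaultdict(float))
    for v, w in node_weight.items():
        cap[(v, "in")][(v, "out")] = w      # cutting here means "save v"
    for u, v in deps:
        cap[(u, "out")][(v, "in")] = INF    # dependencies cannot be cut
    for s in sources:
        cap["S"][(s, "in")] = INF
    for t in sinks:
        cap[(t, "out")]["T"] = INF

    def bfs():
        # Breadth-first search over edges with remaining capacity.
        parent = {"S": None}
        queue = deque(["S"])
        while queue:
            u = queue.popleft()
            for v, c in list(cap[u].items()):
                if c > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        return parent

    while True:
        parent = bfs()
        if "T" not in parent:
            break
        # Walk back from the sink to collect the augmenting path.
        path, v = [], "T"
        while parent[v] is not None:
            path.append((parent[v], v))
            v = parent[v]
        flow = min(cap[u][v] for u, v in path)
        for u, v in path:
            cap[u][v] -= flow
            cap[v][u] += flow

    reachable = set(parent)
    return {v for v in node_weight
            if (v, "in") in reachable and (v, "out") not in reachable}

# Toy chain: x (float32 producer input) -> h (bfloat16) -> b (float32).
deps = [("x", "h"), ("h", "b")]
dtype_bytes = {"x": 4.0, "h": 2.0, "b": 4.0}

# Preference 1 only: weights are the dtype byte sizes.
saved_plain = min_cut_nodes(dtype_bytes, deps, sources={"x"}, sinks={"b"})

# Preference 2 added: hypothetical 0.25 discount on the producer input.
discounted = dict(dtype_bytes, x=dtype_bytes["x"] * 0.25)
saved_discounted = min_cut_nodes(discounted, deps, sources={"x"}, sinks={"b"})

print(saved_plain, saved_discounted)
```

With byte-size weights alone the cut picks the lower-precision tensor `h`; once the producer input `x` is discounted, the cut moves to `x` and everything downstream is recomputed.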

My gut feeling says we would want a segmentation-aware rematerialization pass.

However, if you decide to use min-cut for this, note that "multiway cut" is an NP-hard problem. In Thunder, we go through each producer-consumer pair once, sequentially, in the order in which the producers appear in the trace, and each min-cut computation sees the updated producers and consumers.

For Thunder, it would be useful if there was an ability to query nvFuser's FusionDefinition object for current segmentation boundaries and what intermediates would result in global tensors. Ideally, it should be possible without seeing real strided Tensor inputs. This information could be used in Thunder's rematerialization and memory usage estimation.