NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")
Other
270 stars 53 forks source link

Persistent buffer selection: cache var used by each iteration in grid persistent kernel, e.g. weight in layer norm backward #2525 #402

Closed liqiangxl closed 3 weeks ago

liqiangxl commented 1 year ago

Moved from https://github.com/csarofeen/pytorch/issues/2525

In layer norm backward: For input DataType::Half, the persistent buffers are projected to three inputs (dy, x, weight); the total size is 3 * sizeof(half) * dim1. For input DataType::Float, the persistent buffers are NOT projected; they are xhat and d_xhat, and the total size is 2 * sizeof(float) * dim1. If I enforce projection for input DataType::Float, there is a significant speedup, e.g. for case 2048 x 10240 the time is reduced from 274 us to 207 us, and for case 2048 x 1024 the time is reduced from 39 us to 36 us. The reason is that weight is shared across different rows: if we keep it persistent, we don't need to reload it in each iteration over different rows. The projected version needs more registers per thread, but it doesn't reduce the occupancy ratio, since all the blocks must be active at the same time for this grid persistent kernel.

naoyam commented 1 year ago

Can you please create a repro?

liqiangxl commented 1 year ago

> Can you please create a repro?

Here is the branch: https://github.com/NVIDIA/Fuser/tree/repro_persistent_projection . Check the newly added test case CombinedSchedulerPersistentProjection_CUDA.

Results measured on an A100:

// without projection: kernel1 run in 0.18432 ms, achieved: 1092.8 GB/s
// enforce projection: kernel1 run in 0.154624 ms, achieved: 1302.68 GB/s
// to enforce projection uncomment L1322 to L1336 in normalization.cpp
liqiangxl commented 1 year ago

The weight tensor is the one added with `fusion.addInput(weight);`. It is T4_g in the fusion_ir_math below, which was generated with projection enforced. To enforce projection, uncomment L1322 to L1336 in normalization.cpp.

Inputs:
  T0_g[ iS204{gridDim.y}, iS205{( ceilDiv(i0, gridDim.y) )}, iS213{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, iS209{blockDim.x}, iS210{2}, iS212{1}, iS207{4} ], float
  T1_g[ iS312{gridDim.y}, iS313{( ceilDiv(i3, gridDim.y) )}, iS321{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, iS317{blockDim.x}, iS318{2}, iS320{1}, iS315{4} ], float
  T2_g[ iS292{gridDim.y}, iS293{( ceilDiv(8192, gridDim.y) )}, bS301{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bS297{blockDim.x}, bS298{2}, bS300{1}, bS295{4} ], float
  T3_g[ iS252{gridDim.y}, iS253{( ceilDiv(8192, gridDim.y) )}, bS261{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bS257{blockDim.x}, bS258{2}, bS260{1}, bS255{4} ], float
  T4_g[ iS521{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, iS517{blockDim.x}, iS518{2}, iS520{1}, iS515{4} ], float
  T5_g[ iS9{i9} ], float
Outputs:
  T20_g[ iblockIdx.y392{gridDim.y}, iS393{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y401{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x397{blockDim.x}, iS398{2}, iUS400{1}, iV395{4} ] ca_pos( 6 ) produce_pos( 6 ), float
  T22_g[ iblockIdx.y552{gridDim.y}, ithreadIdx.x550{blockDim.x}, iS551{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iV548{2} ] ca_pos( 3 ) produce_pos( 3 ), float
  T23_g[ iblockIdx.y566{gridDim.y}, ithreadIdx.x564{blockDim.x}, iS565{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iV562{2} ] ca_pos( 3 ) produce_pos( 3 ), float

%kernel_math {
d11 = (double)(i4);
d12 = double(1) * d11;
d47 = reciprocal(d12);
T33_l[ iblockIdx.y242{gridDim.y}, iS243{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y251{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x247{blockDim.x}, bS248{2}, bUS250{1}, bS245{4} ] ca_pos( 2 )
   = Set( T3_g[ iS252{gridDim.y}, iS253{( ceilDiv(8192, gridDim.y) )}, bS261{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bS257{blockDim.x}, bS258{2}, bS260{1}, bS255{4} ] )
T19_l[ iblockIdx.y402{gridDim.y}, iS403{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y411{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x407{blockDim.x}, bS408{2}, bUS410{1}, bS405{4} ] ca_pos( 2 ) produce_pos( 2 )
   = d47
   * T33_l[ iblockIdx.y242{gridDim.y}, iS243{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y251{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x247{blockDim.x}, bS248{2}, bUS250{1}, bS245{4} ] ca_pos( 2 );
T30_l[ iblockIdx.y194{gridDim.y}, iS195{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y203{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x199{blockDim.x}, iS200{2}, iUS202{1}, iV197{4} ] ca_pos( 2 )
   = Set( T0_g[ iS204{gridDim.y}, iS205{( ceilDiv(i0, gridDim.y) )}, iS213{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, iS209{blockDim.x}, iS210{2}, iS212{1}, iS207{4} ] )
T34_l[ ithreadIdx.y513{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x509{blockDim.x}, iS510{2}, iUS512{1}, iV507{4} ]
   = Set( T4_g[ iS521{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, iS517{blockDim.x}, iS518{2}, iS520{1}, iS515{4} ] )
T24_l[ bblockIdx.y450{gridDim.y}, bS451{( ceilDiv(1, gridDim.y) )}, ithreadIdx.y459{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x455{blockDim.x}, iS456{2}, iUS458{1}, iS453{4} ] ca_pos( 7 )
   = broadcast( T34_l[ ithreadIdx.y513{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x509{blockDim.x}, iS510{2}, iUS512{1}, iV507{4} ] )
T25_l[ iblockIdx.y440{gridDim.y}, iS441{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y449{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x445{blockDim.x}, iS446{2}, iUS448{1}, iS443{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T30_l[ iblockIdx.y194{gridDim.y}, iS195{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y203{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x199{blockDim.x}, iS200{2}, iUS202{1}, iV197{4} ] ca_pos( 2 )
   * T24_l[ bblockIdx.y450{gridDim.y}, bS451{( ceilDiv(1, gridDim.y) )}, ithreadIdx.y459{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x455{blockDim.x}, iS456{2}, iUS458{1}, iS453{4} ] ca_pos( 7 );
T10_l[ iblockIdx.y372{gridDim.y}, iS373{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y381{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x377{blockDim.x}, iS378{2}, iUS380{1}, iS375{4} ] ca_pos( 7 ) produce_pos( 7 )
   = d12
   * T25_l[ iblockIdx.y440{gridDim.y}, iS441{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y449{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x445{blockDim.x}, iS446{2}, iUS448{1}, iS443{4} ] ca_pos( 7 ) produce_pos( 7 );
T8_l[ bblockIdx.y184{gridDim.y}, bS185{( ceilDiv(1, gridDim.y) )}, ithreadIdx.y193{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x189{blockDim.x}, iS190{2}, iUS192{1}, iS187{4} ] ca_pos( 7 )
   = broadcast( T34_l[ ithreadIdx.y513{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x509{blockDim.x}, iS510{2}, iUS512{1}, iV507{4} ] )
T9_l[ iblockIdx.y174{gridDim.y}, iS175{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y183{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x179{blockDim.x}, iS180{2}, iUS182{1}, iS177{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T30_l[ iblockIdx.y194{gridDim.y}, iS195{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y203{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x199{blockDim.x}, iS200{2}, iUS202{1}, iV197{4} ] ca_pos( 2 )
   * T8_l[ bblockIdx.y184{gridDim.y}, bS185{( ceilDiv(1, gridDim.y) )}, ithreadIdx.y193{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i8, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x189{blockDim.x}, iS190{2}, iUS192{1}, iS187{4} ] ca_pos( 7 );
T38_l[ iblockIdx.y83{gridDim.y}, iS84{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y92{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}rf, ithreadIdx.x88{blockDim.x}rf, rS89{2}rf, rUS91{1}rf, rS86{4}rf ] ca_pos( 4 ) produce_pos( 7 )
   = reduction( T9_l[ iblockIdx.y174{gridDim.y}, iS175{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y183{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x179{blockDim.x}, iS180{2}, iUS182{1}, iS177{4} ] ca_pos( 7 ) produce_pos( 7 ), op = add, initial value = double(0), allreduce = false )
T11_l[ iblockIdx.y96{gridDim.y}, iS97{( ceilDiv(i0, gridDim.y) )}, rthreadIdx.y94{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, rthreadIdx.x95{blockDim.x} ] ca_pos( 2 ) produce_pos( 4 )
   = reduction( T38_l[ iblockIdx.y83{gridDim.y}, iS84{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y92{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}rf, ithreadIdx.x88{blockDim.x}rf, rS89{2}rf, rUS91{1}rf, rS86{4}rf ] ca_pos( 4 ) produce_pos( 7 ), op = add, initial value = double(0), allreduce = false )
T12_l[ iblockIdx.y362{gridDim.y}, iS363{( ceilDiv(i0, gridDim.y) )}, bthreadIdx.y371{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x367{blockDim.x}, bS368{2}, bUS370{1}, bS365{4} ] ca_pos( 2 ) produce_pos( 2 )
   = broadcast( T11_l[ iblockIdx.y96{gridDim.y}, iS97{( ceilDiv(i0, gridDim.y) )}, rthreadIdx.y94{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, rthreadIdx.x95{blockDim.x} ] ca_pos( 2 ) produce_pos( 4 ) )
T17_l[ iblockIdx.y352{gridDim.y}, iS353{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y361{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x357{blockDim.x}, iS358{2}, iUS360{1}, iS355{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T10_l[ iblockIdx.y372{gridDim.y}, iS373{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y381{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x377{blockDim.x}, iS378{2}, iUS380{1}, iS375{4} ] ca_pos( 7 ) produce_pos( 7 )
   - T12_l[ iblockIdx.y362{gridDim.y}, iS363{( ceilDiv(i0, gridDim.y) )}, bthreadIdx.y371{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x367{blockDim.x}, bS368{2}, bUS370{1}, bS365{4} ] ca_pos( 2 ) produce_pos( 2 );
T31_l[ iblockIdx.y302{gridDim.y}, iS303{( ceilDiv(i3, gridDim.y) )}, ithreadIdx.y311{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x307{blockDim.x}, iS308{2}, iUS310{1}, iV305{4} ] ca_pos( 2 )
   = Set( T1_g[ iS312{gridDim.y}, iS313{( ceilDiv(i3, gridDim.y) )}, iS321{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, iS317{blockDim.x}, iS318{2}, iS320{1}, iS315{4} ] )
T32_l[ iblockIdx.y282{gridDim.y}, iS283{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y291{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x287{blockDim.x}, bS288{2}, bUS290{1}, bS285{4} ] ca_pos( 2 )
   = Set( T2_g[ iS292{gridDim.y}, iS293{( ceilDiv(8192, gridDim.y) )}, bS301{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bS297{blockDim.x}, bS298{2}, bS300{1}, bS295{4} ] )
T27_l[ iblockIdx.y272{gridDim.y}, iS273{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y281{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x277{blockDim.x}, iS278{2}, iUS280{1}, iS275{4} ] ca_pos( 7 ) produce_pos( 2 )
   = T31_l[ iblockIdx.y302{gridDim.y}, iS303{( ceilDiv(i3, gridDim.y) )}, ithreadIdx.y311{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x307{blockDim.x}, iS308{2}, iUS310{1}, iV305{4} ] ca_pos( 2 )
   - T32_l[ iblockIdx.y282{gridDim.y}, iS283{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y291{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x287{blockDim.x}, bS288{2}, bUS290{1}, bS285{4} ] ca_pos( 2 );
T28_l[ iblockIdx.y262{gridDim.y}, iS263{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y271{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x267{blockDim.x}, iS268{2}, iUS270{1}, iS265{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T27_l[ iblockIdx.y272{gridDim.y}, iS273{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y281{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x277{blockDim.x}, iS278{2}, iUS280{1}, iS275{4} ] ca_pos( 7 ) produce_pos( 2 )
   * T33_l[ iblockIdx.y242{gridDim.y}, iS243{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y251{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x247{blockDim.x}, bS248{2}, bUS250{1}, bS245{4} ] ca_pos( 2 );
T6_l[ iblockIdx.y412{gridDim.y}, iS413{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y421{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x417{blockDim.x}, iS418{2}, iUS420{1}, iS415{4} ] ca_pos( 7 ) produce_pos( 2 )
   = T31_l[ iblockIdx.y302{gridDim.y}, iS303{( ceilDiv(i3, gridDim.y) )}, ithreadIdx.y311{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x307{blockDim.x}, iS308{2}, iUS310{1}, iV305{4} ] ca_pos( 2 )
   - T32_l[ iblockIdx.y282{gridDim.y}, iS283{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y291{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x287{blockDim.x}, bS288{2}, bUS290{1}, bS285{4} ] ca_pos( 2 );
T7_l[ iblockIdx.y232{gridDim.y}, iS233{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y241{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x237{blockDim.x}, iS238{2}, iUS240{1}, iS235{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T6_l[ iblockIdx.y412{gridDim.y}, iS413{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y421{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x417{blockDim.x}, iS418{2}, iUS420{1}, iS415{4} ] ca_pos( 7 ) produce_pos( 2 )
   * T33_l[ iblockIdx.y242{gridDim.y}, iS243{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y251{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x247{blockDim.x}, bS248{2}, bUS250{1}, bS245{4} ] ca_pos( 2 );
T13_l[ iblockIdx.y470{gridDim.y}, iS471{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y479{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x475{blockDim.x}, iS476{2}, iUS478{1}, iS473{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T9_l[ iblockIdx.y174{gridDim.y}, iS175{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y183{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x179{blockDim.x}, iS180{2}, iUS182{1}, iS177{4} ] ca_pos( 7 ) produce_pos( 7 )
   * T7_l[ iblockIdx.y232{gridDim.y}, iS233{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y241{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x237{blockDim.x}, iS238{2}, iUS240{1}, iS235{4} ] ca_pos( 7 ) produce_pos( 7 );
T47_l[ iblockIdx.y524{gridDim.y}, iS525{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y533{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}rf, ithreadIdx.x529{blockDim.x}rf, rS530{2}rf, rUS532{1}rf, rS527{4}rf ] ca_pos( 4 ) produce_pos( 7 )
   = reduction( T13_l[ iblockIdx.y470{gridDim.y}, iS471{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y479{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x475{blockDim.x}, iS476{2}, iUS478{1}, iS473{4} ] ca_pos( 7 ) produce_pos( 7 ), op = add, initial value = double(0), allreduce = false )
T14_l[ iblockIdx.y537{gridDim.y}, iS538{( ceilDiv(8192, gridDim.y) )}, rthreadIdx.y535{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, rthreadIdx.x536{blockDim.x} ] ca_pos( 2 ) produce_pos( 4 )
   = reduction( T47_l[ iblockIdx.y524{gridDim.y}, iS525{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y533{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}rf, ithreadIdx.x529{blockDim.x}rf, rS530{2}rf, rUS532{1}rf, rS527{4}rf ] ca_pos( 4 ) produce_pos( 7 ), op = add, initial value = double(0), allreduce = false )
T15_l[ iblockIdx.y332{gridDim.y}, iS333{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y341{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x337{blockDim.x}, bS338{2}, bUS340{1}, bS335{4} ] ca_pos( 2 ) produce_pos( 2 )
   = broadcast( T14_l[ iblockIdx.y537{gridDim.y}, iS538{( ceilDiv(8192, gridDim.y) )}, rthreadIdx.y535{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, rthreadIdx.x536{blockDim.x} ] ca_pos( 2 ) produce_pos( 4 ) )
T16_l[ iblockIdx.y322{gridDim.y}, iS323{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y331{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x327{blockDim.x}, iS328{2}, iUS330{1}, iS325{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T28_l[ iblockIdx.y262{gridDim.y}, iS263{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y271{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x267{blockDim.x}, iS268{2}, iUS270{1}, iS265{4} ] ca_pos( 7 ) produce_pos( 7 )
   * T15_l[ iblockIdx.y332{gridDim.y}, iS333{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y341{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x337{blockDim.x}, bS338{2}, bUS340{1}, bS335{4} ] ca_pos( 2 ) produce_pos( 2 );
T18_l[ iblockIdx.y342{gridDim.y}, iS343{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y351{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x347{blockDim.x}, iS348{2}, iUS350{1}, iS345{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T17_l[ iblockIdx.y352{gridDim.y}, iS353{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y361{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x357{blockDim.x}, iS358{2}, iUS360{1}, iS355{4} ] ca_pos( 7 ) produce_pos( 7 )
   - T16_l[ iblockIdx.y322{gridDim.y}, iS323{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y331{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x327{blockDim.x}, iS328{2}, iUS330{1}, iS325{4} ] ca_pos( 7 ) produce_pos( 7 );
T35_l[ iblockIdx.y382{gridDim.y}, iS383{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y391{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x387{blockDim.x}, iS388{2}, iUS390{1}, iS385{4} ] ca_pos( 6 ) produce_pos( 7 )
   = T19_l[ iblockIdx.y402{gridDim.y}, iS403{( ceilDiv(8192, gridDim.y) )}, bthreadIdx.y411{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), blockDim.x) ), 2) ), 1) )}, bthreadIdx.x407{blockDim.x}, bS408{2}, bUS410{1}, bS405{4} ] ca_pos( 2 ) produce_pos( 2 )
   * T18_l[ iblockIdx.y342{gridDim.y}, iS343{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y351{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x347{blockDim.x}, iS348{2}, iUS350{1}, iS345{4} ] ca_pos( 7 ) produce_pos( 7 );
T20_g[ iblockIdx.y392{gridDim.y}, iS393{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y401{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x397{blockDim.x}, iS398{2}, iUS400{1}, iV395{4} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T35_l[ iblockIdx.y382{gridDim.y}, iS383{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y391{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x387{blockDim.x}, iS388{2}, iUS390{1}, iS385{4} ] ca_pos( 6 ) produce_pos( 7 ) )
T21_l[ iblockIdx.y222{gridDim.y}, iS223{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y231{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x227{blockDim.x}, iS228{2}, iUS230{1}, iS225{4} ] ca_pos( 7 ) produce_pos( 7 )
   = T30_l[ iblockIdx.y194{gridDim.y}, iS195{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y203{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x199{blockDim.x}, iS200{2}, iUS202{1}, iV197{4} ] ca_pos( 2 )
   * T7_l[ iblockIdx.y232{gridDim.y}, iS233{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y241{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i4, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x237{blockDim.x}, iS238{2}, iUS240{1}, iS235{4} ] ca_pos( 7 ) produce_pos( 7 );
T40_l[ iblockIdx.y102{gridDim.y}rf, rS103{( ceilDiv(8192, gridDim.y) )}rf, ithreadIdx.y439{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x435{blockDim.x}, iS436{2}, iUS438{1}, iS433{4} ] ca_pos( 1 ) produce_pos( 7 )
   = reduction( T21_l[ iblockIdx.y222{gridDim.y}, iS223{( ceilDiv(8192, gridDim.y) )}, ithreadIdx.y231{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x227{blockDim.x}, iS228{2}, iUS230{1}, iS225{4} ] ca_pos( 7 ) produce_pos( 7 ), op = add, initial value = double(0), allreduce = false )
T39_g[ iblockIdx.y106{gridDim.y}, ithreadIdx.y497{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x493{blockDim.x}, iS494{2}, iUS496{1}, iV491{4} ] produce_pos( 1 )
   = Set( T40_l[ iblockIdx.y102{gridDim.y}rf, rS103{( ceilDiv(8192, gridDim.y) )}rf, ithreadIdx.y439{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x435{blockDim.x}, iS436{2}, iUS438{1}, iS433{4} ] ca_pos( 1 ) produce_pos( 7 ) )
T41_l[ iblockIdx.y544{gridDim.y}, ithreadIdx.x542{blockDim.x}, iS543{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS545{( ceilDiv(gridDim.y, blockDim.y) )}, iV540{2}, ithreadIdx.y546{blockDim.y} ] ca_pos( 4 )
   = Set( T39_g[ iblockIdx.y106{gridDim.y}, ithreadIdx.y497{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x493{blockDim.x}, iS494{2}, iUS496{1}, iV491{4} ] produce_pos( 1 ) )
T42_l[ iblockIdx.y125{gridDim.y}, ithreadIdx.x123{blockDim.x}, iS124{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, rS126{( ceilDiv(gridDim.y, blockDim.y) )}rf, iS121{2}, ithreadIdx.y127{blockDim.y}rf ] ca_pos( 3 ) produce_pos( 4 )
   = reduction( T41_l[ iblockIdx.y544{gridDim.y}, ithreadIdx.x542{blockDim.x}, iS543{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS545{( ceilDiv(gridDim.y, blockDim.y) )}, iV540{2}, ithreadIdx.y546{blockDim.y} ] ca_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T36_l[ iblockIdx.y135{gridDim.y}, ithreadIdx.x133{blockDim.x}, iS134{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS131{2}, rthreadIdx.y128{blockDim.y} ] ca_pos( 3 ) produce_pos( 3 )
   = reduction( T42_l[ iblockIdx.y125{gridDim.y}, ithreadIdx.x123{blockDim.x}, iS124{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, rS126{( ceilDiv(gridDim.y, blockDim.y) )}rf, iS121{2}, ithreadIdx.y127{blockDim.y}rf ] ca_pos( 3 ) produce_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T22_g[ iblockIdx.y552{gridDim.y}, ithreadIdx.x550{blockDim.x}, iS551{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iV548{2} ] ca_pos( 3 ) produce_pos( 3 )
   = Set( T36_l[ iblockIdx.y135{gridDim.y}, ithreadIdx.x133{blockDim.x}, iS134{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS131{2}, rthreadIdx.y128{blockDim.y} ] ca_pos( 3 ) produce_pos( 3 ) )
T44_l[ iblockIdx.y140{gridDim.y}rf, rS141{( ceilDiv(i0, gridDim.y) )}rf, ithreadIdx.y221{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x217{blockDim.x}, iS218{2}, iUS220{1}, iS215{4} ] ca_pos( 1 ) produce_pos( 2 )
   = reduction( T30_l[ iblockIdx.y194{gridDim.y}, iS195{( ceilDiv(i0, gridDim.y) )}, ithreadIdx.y203{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x199{blockDim.x}, iS200{2}, iUS202{1}, iV197{4} ] ca_pos( 2 ), op = add, initial value = double(0), allreduce = false )
T43_g[ iblockIdx.y144{gridDim.y}, ithreadIdx.y505{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x501{blockDim.x}, iS502{2}, iUS504{1}, iV499{4} ] produce_pos( 1 )
   = Set( T44_l[ iblockIdx.y140{gridDim.y}rf, rS141{( ceilDiv(i0, gridDim.y) )}rf, ithreadIdx.y221{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x217{blockDim.x}, iS218{2}, iUS220{1}, iS215{4} ] ca_pos( 1 ) produce_pos( 2 ) )
T45_l[ iblockIdx.y558{gridDim.y}, ithreadIdx.x556{blockDim.x}, iS557{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS559{( ceilDiv(gridDim.y, blockDim.y) )}, iV554{2}, ithreadIdx.y560{blockDim.y} ] ca_pos( 4 )
   = Set( T43_g[ iblockIdx.y144{gridDim.y}, ithreadIdx.y505{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) ), 2) ), 1) )}, ithreadIdx.x501{blockDim.x}, iS502{2}, iUS504{1}, iV499{4} ] produce_pos( 1 ) )
T46_l[ iblockIdx.y163{gridDim.y}, ithreadIdx.x161{blockDim.x}, iS162{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, rS164{( ceilDiv(gridDim.y, blockDim.y) )}rf, iS159{2}, ithreadIdx.y165{blockDim.y}rf ] ca_pos( 3 ) produce_pos( 4 )
   = reduction( T45_l[ iblockIdx.y558{gridDim.y}, ithreadIdx.x556{blockDim.x}, iS557{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS559{( ceilDiv(gridDim.y, blockDim.y) )}, iV554{2}, ithreadIdx.y560{blockDim.y} ] ca_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T37_l[ iblockIdx.y173{gridDim.y}, ithreadIdx.x171{blockDim.x}, iS172{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS169{2}, rthreadIdx.y166{blockDim.y} ] ca_pos( 3 ) produce_pos( 3 )
   = reduction( T46_l[ iblockIdx.y163{gridDim.y}, ithreadIdx.x161{blockDim.x}, iS162{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, rS164{( ceilDiv(gridDim.y, blockDim.y) )}rf, iS159{2}, ithreadIdx.y165{blockDim.y}rf ] ca_pos( 3 ) produce_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T23_g[ iblockIdx.y566{gridDim.y}, ithreadIdx.x564{blockDim.x}, iS565{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iV562{2} ] ca_pos( 3 ) produce_pos( 3 )
   = Set( T37_l[ iblockIdx.y173{gridDim.y}, ithreadIdx.x171{blockDim.x}, iS172{( ceilDiv(( ceilDiv(( ceilDiv(i2, 2) ), blockDim.x) ), gridDim.y) )}, iS169{2}, rthreadIdx.y166{blockDim.y} ] ca_pos( 3 ) produce_pos( 3 ) )
}
liqiangxl commented 1 year ago

Use the latest code at https://github.com/NVIDIA/Fuser/tree/repro_persistent_projection Run with: /opt/pytorch/nvfuser/build/nvfuser_bench --benchmark_min_time=0.01 --benchmark_filter=NvFuserScheduler_LayerNorm_BWD_Project_fp32 2>&1 |tee 1.log

Enforce projection:

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                                       Time             CPU   Iterations UserCounters...
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/8192/manual_time         596 us          675 us           23 bytes_per_second=1.22935T/s Red On Fastest Dim // Persistent Kernel // Project Persistent Buffers //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 10 / vectorize / factor 4/Launch_Parameters[block(1/7/32)/grid(1/108/1)/896]
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/10240/manual_time        734 us          840 us           19 bytes_per_second=1.24815T/s Red On Fastest Dim // Persistent Kernel // Project Persistent Buffers //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 10 / vectorize / factor 4/Launch_Parameters[block(1/8/32)/grid(1/108/1)/1024]
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/12288/manual_time        877 us          965 us           16 bytes_per_second=1.25326T/s Red On Fastest Dim // Persistent Kernel // Project Persistent Buffers //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 12 / vectorize / factor 4/Launch_Parameters[block(1/8/32)/grid(1/108/1)/1024]
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/14336/manual_time       1089 us         1170 us           13 bytes_per_second=1.17758T/s Red On Fastest Dim // Persistent Kernel // Project Persistent Buffers //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 14 / vectorize / factor 4/Launch_Parameters[block(1/4/64)/grid(1/108/1)/1024]

without projection (comment out L1355 to L1369 in normalization.cpp):

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                                       Time             CPU   Iterations UserCounters...
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/8192/manual_time        1011 us         1092 us           14 bytes_per_second=742.297G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 13 / vectorize / factor 4/Launch_Parameters[block(1/5/32)/grid(1/108/1)/640]
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/10240/manual_time       1074 us         1170 us           13 bytes_per_second=872.762G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 12 / vectorize / factor 4/Launch_Parameters[block(1/7/32)/grid(1/108/1)/896]
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/12288/manual_time       1191 us         1272 us           12 bytes_per_second=944.782G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 12 / vectorize / factor 4/Launch_Parameters[block(1/8/32)/grid(1/108/1)/1024]
NvFuserScheduler_LayerNorm_BWD_Project_fp32___GRAPH/NvFuserScheduler_LayerNorm_BWD_Project_fp32/8192/14336/manual_time       1378 us         1467 us           10 bytes_per_second=952.591G/s Red On Fastest Dim // Persistent Kernel //  // Iteration Domain: split grid dimension /  // Inner Reduction Domain: persistent batch - 14 / vectorize / factor 4/Launch_Parameters[block(1/4/64)/grid(1/108/1)/1024]