NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

python_tests.test_normalization.test_instance_norm_multigpu failure #1728

Open · naoyam opened this issue 8 months ago

naoyam commented 8 months ago

This test started failing a couple of days ago.

python_tests.test_normalization.test_instance_norm_multigpu

See also https://gitlab-master.nvidia.com/dl/pytorch/update-scripts/-/issues/50225

jacobhinkle commented 8 months ago

The failure does not depend on multi-GPU. Here's a single-device repro:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1, -1], contiguity=[True, True, True, True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T3 = fd.define_tensor(shape=[-1, -1, -1, -1, -1], contiguity=[True, None, None, None, None], dtype=DataType.Float, is_cpu=False)
    S4, S5, S6, S7, S8, = fd.ops.tensor_sizes(T0)
    S9 = fd.define_scalar(1, dtype=DataType.Int)
    S10 = fd.ops.mul(S9, S6)
    S11 = fd.ops.mul(S10, S7)
    S12 = fd.ops.mul(S11, S8)
    T13 = fd.ops.broadcast(T1, is_broadcast_dim=[False, False, True, True, True])
    S14 = fd.ops.reciprocal(S12)
    T15 = fd.ops.sum(T3, axes=[2, 3, 4], keepdim=False, dtype=DataType.Null)
    T16 = fd.ops.sub(T0, T13)
    T17 = fd.ops.mul(T3, T16)
    T18 = fd.ops.sum(T17, axes=[2, 3, 4], keepdim=False, dtype=DataType.Null)
    T19 = fd.ops.mul(T15, S14)
    T20 = fd.ops.broadcast(T19, is_broadcast_dim=[False, False, True, True, True])
    T21 = fd.ops.mul(T18, S14)
    T22 = fd.ops.mul(T2, T2)
    T23 = fd.ops.mul(T21, T22)
    T24 = fd.ops.broadcast(T23, is_broadcast_dim=[False, False, True, True, True])
    T25 = fd.ops.broadcast(T2, is_broadcast_dim=[False, False, True, True, True])
    T26 = fd.ops.sub(T0, T13)
    T27 = fd.ops.mul(T26, T24)
    T28 = fd.ops.sub(T3, T27)
    T29 = fd.ops.sub(T28, T20)
    T30 = fd.ops.mul(T29, T25)
    fd.add_output(T30)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((16777216,), dtype=torch.float32, device='cuda:0').as_strided((2, 4, 128, 128, 128), (8388608, 2097152, 16384, 128, 1)),
    torch.randn((8,), dtype=torch.float32, device='cuda:0').as_strided((2, 4), (4, 1)),
    torch.randn((8,), dtype=torch.float32, device='cuda:0').as_strided((2, 4), (4, 1)),
    torch.randn((2,), dtype=torch.float32, device='cuda:0').as_strided((2, 4, 128, 128, 128), (1, 0, 0, 0, 0)),
]
fd.execute(inputs)

RuntimeError: producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":763, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV22 (T22_l[ bblockIdx.x239{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) ) ex ( ceilDiv(( ceilDiv(( i12 * ( i13 * i14 ) ), 4) ), blockDim.x) )}, iblockIdx.y242{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS243{1}, bS238{4}, bthreadIdx.x240{blockDim.x} ]) and TV19(T19_l[ iblockIdx.x194{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y197{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS198{1}, iS193{4}, ithreadIdx.x195{blockDim.x} ] ca_pos( 5 ) produce_pos( 5 )). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.y)
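(As an aside, the last input above is an expanded tensor: an as_strided view with zero strides is just a broadcast expand, so the input materializes only two values. A small sketch of this equivalence, using the repro's shapes; base is an illustrative name:)

import torch

base = torch.randn(2)
# Zero strides in dims 1-4 map every index in those dims to the same
# storage element, i.e., an expanded broadcast of shape (2, 1, 1, 1, 1).
a = base.as_strided((2, 4, 128, 128, 128), (1, 0, 0, 0, 0))
b = base.view(2, 1, 1, 1, 1).expand(2, 4, 128, 128, 128)
assert torch.equal(a, b)  # both are views; no data is materialized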

Scheduled Fusion IR for this segment is

Inputs:
  T3_g[ bS248{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) ) ex ( ceilDiv(( ceilDiv(( i12 * ( i13 * i14 ) ), 4) ), blockDim.x) )}, iS251{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iS252{1}, bS247{4}, bS249{blockDim.x} ], float
  T2_g[ iS281{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iS282{1} ], float
  T9_g[ iS272{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iS273{1}, rS337{( (( (( getMetaData(T17) )).logical_size ))[2] )}, rS338{( (( (( getMetaData(T17) )).logical_size ))[3] )}, rS339{( (( (( getMetaData(T17) )).logical_size ))[4] )} ], float
  T17_g[ iS230{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iS233{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iS234{1}, iS229{4}, iS231{blockDim.x} ], float
Outputs:
  T21_g[ iblockIdx.x150{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y152{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS153{1}, iV149{4}, ithreadIdx.x151{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 ), float

%kernel_math {
T22_l[ bblockIdx.x239{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) ) ex ( ceilDiv(( ceilDiv(( i12 * ( i13 * i14 ) ), 4) ), blockDim.x) )}, iblockIdx.y242{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS243{1}, bS238{4}, bthreadIdx.x240{blockDim.x} ]
   = Set( T3_g[ bS248{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) ) ex ( ceilDiv(( ceilDiv(( i12 * ( i13 * i14 ) ), 4) ), blockDim.x) )}, iS251{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iS252{1}, bS247{4}, bS249{blockDim.x} ], cache_op=AllLevels )
T25_l[ iblockIdx.x221{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y224{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS225{1}, iV220{4}, ithreadIdx.x222{blockDim.x} ] ca_pos( 3 )
   = Set( T17_g[ iS230{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iS233{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iS234{1}, iS229{4}, iS231{blockDim.x} ], cache_op=Streaming )
T24_l[ iblockIdx.y269{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS270{1} ] ca_pos( 2 )
   = Set( T9_g[ iS272{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iS273{1}, rS337{( (( (( getMetaData(T17) )).logical_size ))[2] )}, rS338{( (( (( getMetaData(T17) )).logical_size ))[3] )}, rS339{( (( (( getMetaData(T17) )).logical_size ))[4] )} ], cache_op=AllLevels )
s1027 = getMetaData(T17_g[ iS230{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iS233{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iS234{1}, iS229{4}, iS231{blockDim.x} ])
a1028 = s1027.logical_size
i1030 = a1028[2]
i17 = 1 * i1030;
a1032 = s1027.logical_size
i1034 = a1032[3]
i19 = i17 * i1034;
a1036 = s1027.logical_size
i1038 = a1036[4]
i21 = i19 * i1038;
f24 = (float)(i21);
f26 = reciprocal(f24);
T12_l[ iblockIdx.y266{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS267{1} ] ca_pos( 2 ) produce_pos( 2 )
   = T24_l[ iblockIdx.y269{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS270{1} ] ca_pos( 2 )
   * f26;
T23_l[ iblockIdx.y278{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS279{1} ] ca_pos( 2 )
   = Set( T2_g[ iS281{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iS282{1} ], cache_op=AllLevels )
T13_l[ iblockIdx.y263{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS264{1} ] ca_pos( 2 ) produce_pos( 2 )
   = T23_l[ iblockIdx.y278{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS279{1} ] ca_pos( 2 )
   * T23_l[ iblockIdx.y278{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS279{1} ] ca_pos( 2 );
T14_l[ iblockIdx.y260{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS261{1} ] ca_pos( 2 ) produce_pos( 2 )
   = T12_l[ iblockIdx.y266{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS267{1} ] ca_pos( 2 ) produce_pos( 2 )
   * T13_l[ iblockIdx.y263{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS264{1} ] ca_pos( 2 ) produce_pos( 2 );
T15_l[ bblockIdx.x212{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) )}, iblockIdx.y215{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS216{1}, bS211{4}, bthreadIdx.x213{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T14_l[ iblockIdx.y260{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS261{1} ] ca_pos( 2 ) produce_pos( 2 ) )
T18_l[ iblockIdx.x203{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y206{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS207{1}, iS202{4}, ithreadIdx.x204{blockDim.x} ] ca_pos( 5 ) produce_pos( 3 )
   = T25_l[ iblockIdx.x221{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y224{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS225{1}, iV220{4}, ithreadIdx.x222{blockDim.x} ] ca_pos( 3 )
   * T15_l[ bblockIdx.x212{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) )}, iblockIdx.y215{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS216{1}, bS211{4}, bthreadIdx.x213{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 );
T19_l[ iblockIdx.x194{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y197{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS198{1}, iS193{4}, ithreadIdx.x195{blockDim.x} ] ca_pos( 5 ) produce_pos( 5 )
   = T22_l[ bblockIdx.x239{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) ) ex ( ceilDiv(( ceilDiv(( i12 * ( i13 * i14 ) ), 4) ), blockDim.x) )}, iblockIdx.y242{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS243{1}, bS238{4}, bthreadIdx.x240{blockDim.x} ]
   - T18_l[ iblockIdx.x203{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y206{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS207{1}, iS202{4}, ithreadIdx.x204{blockDim.x} ] ca_pos( 5 ) produce_pos( 3 );
T5_l[ iblockIdx.y254{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS255{1} ] ca_pos( 2 )
   = squeeze( T22_l[ bblockIdx.x239{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) ) ex ( ceilDiv(( ceilDiv(( i12 * ( i13 * i14 ) ), 4) ), blockDim.x) )}, iblockIdx.y242{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS243{1}, bS238{4}, bthreadIdx.x240{blockDim.x} ] )
s201 = getMetaData(T3_g[ bS248{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) ) ex ( ceilDiv(( ceilDiv(( i12 * ( i13 * i14 ) ), 4) ), blockDim.x) )}, iS251{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iS252{1}, bS247{4}, bS249{blockDim.x} ])
a1015 = s201.logical_size
i1017 = a1015[2]
a1019 = s201.logical_size
i1021 = a1019[3]
i30 = i1017 * i1021;
a1023 = s201.logical_size
i1025 = a1023[4]
i32 = i30 * i1025;
f34 = (float)(i32);
T6_l[ iblockIdx.y257{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS258{1} ] ca_pos( 2 ) produce_pos( 2 )
   = T5_l[ iblockIdx.y254{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS255{1} ] ca_pos( 2 )
   * f34;
T10_l[ iblockIdx.y275{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS276{1} ] ca_pos( 2 ) produce_pos( 2 )
   = T6_l[ iblockIdx.y257{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS258{1} ] ca_pos( 2 ) produce_pos( 2 )
   * f26;
T11_l[ bblockIdx.x185{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) )}, iblockIdx.y188{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS189{1}, bS184{4}, bthreadIdx.x186{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T10_l[ iblockIdx.y275{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS276{1} ] ca_pos( 2 ) produce_pos( 2 ) )
T20_l[ iblockIdx.x176{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y179{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS180{1}, iS175{4}, ithreadIdx.x177{blockDim.x} ] ca_pos( 5 ) produce_pos( 5 )
   = T19_l[ iblockIdx.x194{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y197{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS198{1}, iS193{4}, ithreadIdx.x195{blockDim.x} ] ca_pos( 5 ) produce_pos( 5 )
   - T11_l[ bblockIdx.x185{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) )}, iblockIdx.y188{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * 1 ), 1) )}, iUS189{1}, bS184{4}, bthreadIdx.x186{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 );
T16_l[ bblockIdx.x167{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) )}, iblockIdx.y170{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS171{1}, bS166{4}, bthreadIdx.x168{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T23_l[ iblockIdx.y278{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS279{1} ] ca_pos( 2 ) )
T26_l[ iblockIdx.x158{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y161{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS162{1}, iS157{4}, ithreadIdx.x159{blockDim.x} ] ca_pos( 3 ) produce_pos( 5 )
   = T20_l[ iblockIdx.x176{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y179{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS180{1}, iS175{4}, ithreadIdx.x177{blockDim.x} ] ca_pos( 5 ) produce_pos( 5 )
   * T16_l[ bblockIdx.x167{( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), blockDim.x) )}, iblockIdx.y170{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS171{1}, bS166{4}, bthreadIdx.x168{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 );
T21_g[ iblockIdx.x150{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y152{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS153{1}, iV149{4}, ithreadIdx.x151{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 )
   = Set( T26_l[ iblockIdx.x158{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T17) )).logical_size ))[2] ) * ( ( (( (( getMetaData(T17) )).logical_size ))[3] ) * ( (( (( getMetaData(T17) )).logical_size ))[4] ) ) ), 4) ), blockDim.x) )}, iblockIdx.y161{( ceilDiv(( ( (( (( getMetaData(T2) )).logical_size ))[0] ) * ( (( (( getMetaData(T2) )).logical_size ))[1] ) ), 1) )}, iUS162{1}, iS157{4}, ithreadIdx.x159{blockDim.x} ] ca_pos( 3 ) produce_pos( 5 ), cache_op=Streaming )
}
jacobhinkle commented 8 months ago

T3 in the python repro is an input that is expanded in the last three dims, and we are now translating the sum T15 into a squeeze. Unscheduled fusion IR:

Inputs:
  T0_g[ iS0{i0}, iS1{i1}, iS2{i2}, iS3{i3}, iS4{i4} ], float
  T1_g[ iS87{i0}, iS88{i1} ], float
  T2_g[ iS94{i0}, iS95{i1} ], float
  T3_g[ iS86{i0}, bS10{1 ex i11}, bS11{1 ex i12}, bS12{1 ex i13}, bS13{1 ex i14} ], float
Outputs:
  T21_g[ iS108{i0}, iS82{i1}, iS83{i2}, iS84{i3}, iS85{i4} ], float

%kernel_math {
T4_l[ iS89{i0}, iS90{i1}, bS16{1}, bS17{1}, bS18{1} ]
   = broadcast( T1_g[ iS87{i0}, iS88{i1} ] )
T17_l[ iS61{i0}, iS62{i1}, iS63{i2}, iS64{i3}, iS65{i4} ]
   = T0_g[ iS0{i0}, iS1{i1}, iS2{i2}, iS3{i3}, iS4{i4} ]
   - T4_l[ iS89{i0}, iS90{i1}, bS16{1}, bS17{1}, bS18{1} ];
T7_l[ iS23{i0}, iS24{i1}, iS25{i2}, iS26{i3}, iS27{i4} ]
   = T0_g[ iS0{i0}, iS1{i1}, iS2{i2}, iS3{i3}, iS4{i4} ]
   - T4_l[ iS89{i0}, iS90{i1}, bS16{1}, bS17{1}, bS18{1} ];
T8_l[ iS91{i0}, iS29{i1}, iS30{i2}, iS31{i3}, iS32{i4} ]
   = T3_g[ iS86{i0}, bS10{1 ex i11}, bS11{1 ex i12}, bS12{1 ex i13}, bS13{1 ex i14} ]
   * T7_l[ iS23{i0}, iS24{i1}, iS25{i2}, iS26{i3}, iS27{i4} ];
T9_l[ iS92{i0}, iS34{i1}, rS35{i2}, rS36{i3}, rS37{i4} ]
   = reduction( T8_l[ iS91{i0}, iS29{i1}, iS30{i2}, iS31{i3}, iS32{i4} ], op = add, initial value = float(0), allreduce = false )
i17 = 1 * i2;
i19 = i17 * i3;
i21 = i19 * i4;
f24 = (float)(i21);
f26 = reciprocal(f24);
T12_l[ iS93{i0}, iS46{i1} ]
   = T9_l[ iS92{i0}, iS34{i1}, rS35{i2}, rS36{i3}, rS37{i4} ]
   * f26;
T13_l[ iS96{i0}, iS97{i1} ]
   = T2_g[ iS94{i0}, iS95{i1} ]
   * T2_g[ iS94{i0}, iS95{i1} ];
T14_l[ iS98{i0}, iS50{i1} ]
   = T12_l[ iS93{i0}, iS46{i1} ]
   * T13_l[ iS96{i0}, iS97{i1} ];
T15_l[ iS99{i0}, iS52{i1}, bS53{1}, bS54{1}, bS55{1} ]
   = broadcast( T14_l[ iS98{i0}, iS50{i1} ] )
T18_l[ iS66{i0}, iS67{i1}, iS68{i2}, iS69{i3}, iS70{i4} ]
   = T17_l[ iS61{i0}, iS62{i1}, iS63{i2}, iS64{i3}, iS65{i4} ]
   * T15_l[ iS99{i0}, iS52{i1}, bS53{1}, bS54{1}, bS55{1} ];
T19_l[ iS100{i0}, iS72{i1}, iS73{i2}, iS74{i3}, iS75{i4} ]
   = T3_g[ iS86{i0}, bS10{1 ex i11}, bS11{1 ex i12}, bS12{1 ex i13}, bS13{1 ex i14} ]
   - T18_l[ iS66{i0}, iS67{i1}, iS68{i2}, iS69{i3}, iS70{i4} ];
T5_l[ iS101{i0}, bS20{1 ex i11} ]
   = squeeze( T3_g[ iS86{i0}, bS10{1 ex i11}, bS11{1 ex i12}, bS12{1 ex i13}, bS13{1 ex i14} ] )
i30 = i12 * i13;
i32 = i30 * i14;
f34 = (float)(i32);
T6_l[ iS102{i0}, bS22{1 ex i11} ]
   = T5_l[ iS101{i0}, bS20{1 ex i11} ]
   * f34;
T10_l[ iS103{i0}, bS39{1 ex i11} ]
   = T6_l[ iS102{i0}, bS22{1 ex i11} ]
   * f26;
T11_l[ iS104{i0}, bS41{1 ex i11}, bS42{1}, bS43{1}, bS44{1} ]
   = broadcast( T10_l[ iS103{i0}, bS39{1 ex i11} ] )
T20_l[ iS105{i0}, iS77{i1}, iS78{i2}, iS79{i3}, iS80{i4} ]
   = T19_l[ iS100{i0}, iS72{i1}, iS73{i2}, iS74{i3}, iS75{i4} ]
   - T11_l[ iS104{i0}, bS41{1 ex i11}, bS42{1}, bS43{1}, bS44{1} ];
T16_l[ iS106{i0}, iS107{i1}, bS58{1}, bS59{1}, bS60{1} ]
   = broadcast( T2_g[ iS94{i0}, iS95{i1} ] )
T21_g[ iS108{i0}, iS82{i1}, iS83{i2}, iS84{i3}, iS85{i4} ]
   = T20_l[ iS105{i0}, iS77{i1}, iS78{i2}, iS79{i3}, iS80{i4} ]
   * T16_l[ iS106{i0}, iS107{i1}, bS58{1}, bS59{1}, bS60{1} ];
}

Although the error occurs at T19_l, which is a sibling of T5_l, it could still have something to do with the squeeze. Maybe the squeeze of expanded broadcasts messed up transform/parallelization propagation in this case?
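(Background on the sum-to-squeeze translation mentioned above: summing an expanded tensor over its expanded dims just scales the underlying value by the number of summed elements, which matches the squeeze-then-multiply-by-f34 pattern in the IR. A quick sketch, with sizes picked purely for illustration:)

import torch

t = torch.randn(2, 4, 1, 1, 1).expand(2, 4, 8, 8, 8)
summed = t.sum(dim=[2, 3, 4])               # reduction over the expanded dims
squeezed = t[:, :, 0, 0, 0] * (8 * 8 * 8)   # squeeze, then multiply by the count
assert torch.allclose(summed, squeezed)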

jjsjann123 commented 8 months ago

Just an FYI: I temporarily skipped the test so we don't have noisy CI.

jacobhinkle commented 8 months ago

Just an FYI: I temporarily skipped the test so we don't have noisy CI.

Thank you!

jacobhinkle commented 8 months ago

This bug has nothing to do with the squeeze of expanded dims, or with reductions. It was exposed by #1679 but is really just an issue with squeeze+broadcast.

Here is a close-to-minimal repro:

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(
        shape=[-1, -1, -1, -1],
        contiguity=[True, True, True, True],
        dtype=DataType.Float,
        is_cpu=False,
    )
    T1 = fd.define_tensor(
        shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False
    )
    T2 = fd.define_tensor(
        shape=[-1, 1, 1, 1],
        contiguity=[True, None, None, None],
        dtype=DataType.Float,
        is_cpu=False,
    )
    #        T1         T2
    #        b|       s/ |
    #   T0   T5     T3   /
    #     \  /     b/   /
    #      T6     T4   /
    #        \   /    /
    #         T7     /
    #           \   /
    #            T8
    T3 = fd.ops.squeeze(T2, dims=[2, 3])
    T4 = fd.ops.broadcast(T3, is_broadcast_dim=[False, False, True, True])
    T5 = fd.ops.broadcast(T1, is_broadcast_dim=[False, False, True, True])
    T6 = fd.ops.mul(T0, T5)
    T7 = fd.ops.mul(T6, T4)
    T8 = fd.ops.mul(T2, T7)
    fd.add_output(T8)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((131072,), dtype=torch.float32, device="cuda:0").as_strided(
        (2, 4, 128, 128), (65536, 16384, 128, 1)
    ),
    torch.randn((8,), dtype=torch.float32, device="cuda:0").as_strided((2, 4), (4, 1)),
    torch.randn((2,), dtype=torch.float32, device="cuda:0").as_strided(
        (2, 1, 1, 1), (1, 1, 1, 1)
    ),
]
fd.execute(inputs)

I verified that this repro would have failed on main even before the merge of #1679 (tested with 110cb3a).

naoyam commented 8 months ago

Are you sure the repro is correct? T4 and T5 don't seem to be used.

jacobhinkle commented 8 months ago

Oh, I copied it incorrectly. But the implicit broadcasts will introduce those tensors anyway. I'll update the repro in the comment above.

naoyam commented 8 months ago

I tried the repro quickly. It seems wrong that a RAW sync is deemed to be required for blockIdx.y.

jacobhinkle commented 8 months ago

Yeah, useSameIndex is failing there with this check: https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/analysis/sync_information.cpp#L425.
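(Roughly, that check asks whether the producer and consumer derive the same index from each parallel dimension. A toy model, with made-up names rather than nvFuser's actual API, of how the first error message arises:)

# Hypothetical sketch, not nvFuser code: a RAW flag is raised for a parallel
# type whenever the producer and consumer do not index it the same way.
def raw_flags(producer_extents, consumer_extents):
    return [
        ptype
        for ptype, extent in consumer_extents.items()
        if producer_extents.get(ptype) != extent
    ]

# In the first error above, T22 spans c0 * 1 along blockIdx.y while T19
# spans c0 * c1, hence "RAW flags: (blockIdx.y)".
print(raw_flags({"blockIdx.y": "c0 * 1"}, {"blockIdx.y": "c0 * c1"}))  # ['blockIdx.y']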

naoyam commented 8 months ago

C++ repro:

TEST_F(NVFuserTest, TMP) {
  auto fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(4);
  fusion.addInput(tv0);
  auto tv1 = makeSymbolicTensor(2);
  fusion.addInput(tv1);
  auto tv2 = makeConcreteTensor({-1, 1, 1, 1});
  fusion.addInput(tv2);

  auto tv3 = squeeze(tv2, std::vector<int64_t>{2, 3});
  auto tv4 = broadcast(tv3, {false, false, true, true});
  auto tv5 = broadcast(tv1, {false, false, true, true});
  auto tv6 = mul(tv0, tv5);
  auto tv7 = mul(tv6, tv4);
  auto tv8 = mul(tv2, tv7);
  fusion.addOutput(tv8);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({2, 4, 128, 128}, options);
  auto t1 = at::randn({2, 4}, options);
  auto t2 = at::randn({2, 1, 1, 1}, options);
  std::vector<c10::IValue> aten_inputs({t0, t1, t2});

  FusionExecutorCache fec(std::move(fusion_ptr));
  auto cg_outputs = fec.runFusionWithInputs(aten_inputs);
}

naoyam commented 8 months ago

%kernel {
T11_l[ bblockIdx.x131{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y134{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS135{1}, bS130{4}, bthreadIdx.x132{blockDim.x} ]
   = Set( T2_g[ bS139{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iS142{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iS143{1}, bS138{4}, bS140{blockDim.x} ], cache_op=AllLevels )
T9_l[ iblockIdx.x115{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y118{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS119{1}, iS114{4}, ithreadIdx.x116{blockDim.x} ] ca_pos( 3 )
   = Set( T0_g[ iS123{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iS126{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iS127{1}, iS122{4}, iS124{blockDim.x} ], cache_op=Streaming )
T10_l[ iblockIdx.y145{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS146{1} ] ca_pos( 2 )
   = Set( T1_g[ iS148{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iS149{1} ], cache_op=AllLevels )
T5_l[ bblockIdx.x107{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y110{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS111{1}, bS106{4}, bthreadIdx.x108{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T10_l[ iblockIdx.y145{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS146{1} ] ca_pos( 2 ) )
T6_l[ iblockIdx.x99{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y102{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS103{1}, iS98{4}, ithreadIdx.x100{blockDim.x} ] ca_pos( 5 ) produce_pos( 3 )
   = T9_l[ iblockIdx.x115{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y118{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS119{1}, iS114{4}, ithreadIdx.x116{blockDim.x} ] ca_pos( 3 )
   * T5_l[ bblockIdx.x107{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y110{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS111{1}, bS106{4}, bthreadIdx.x108{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 );
T3_l[ iblockIdx.y151{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS152{1} ] ca_pos( 2 )
   = squeeze( T11_l[ bblockIdx.x131{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y134{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS135{1}, bS130{4}, bthreadIdx.x132{blockDim.x} ] )
T4_l[ bblockIdx.x91{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y94{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS95{1}, bS90{4}, bthreadIdx.x92{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 )
   = broadcast( T3_l[ iblockIdx.y151{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS152{1} ] ca_pos( 2 ) )
T7_l[ iblockIdx.x83{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y86{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS87{1}, iS82{4}, ithreadIdx.x84{blockDim.x} ] ca_pos( 5 ) produce_pos( 5 )
   = T6_l[ iblockIdx.x99{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y102{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS103{1}, iS98{4}, ithreadIdx.x100{blockDim.x} ] ca_pos( 5 ) produce_pos( 3 )
   * T4_l[ bblockIdx.x91{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y94{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS95{1}, bS90{4}, bthreadIdx.x92{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 );
T12_l[ iblockIdx.x75{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y78{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS79{1}, iS74{4}, ithreadIdx.x76{blockDim.x} ] ca_pos( 3 ) produce_pos( 5 )
   = T11_l[ bblockIdx.x131{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y134{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS135{1}, bS130{4}, bthreadIdx.x132{blockDim.x} ]
   * T7_l[ iblockIdx.x83{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y86{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS87{1}, iS82{4}, ithreadIdx.x84{blockDim.x} ] ca_pos( 5 ) produce_pos( 5 );
T8_g[ iblockIdx.x68{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y70{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS71{1}, iV67{4}, ithreadIdx.x69{blockDim.x} ] ca_pos( 3 ) produce_pos( 3 )
   = Set( T12_l[ iblockIdx.x75{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[2] ) * ( (( (( getMetaData(T0) )).logical_size ))[3] ) ), 4) ), blockDim.x) )}, iblockIdx.y78{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, iUS79{1}, iS74{4}, ithreadIdx.x76{blockDim.x} ] ca_pos( 3 ) produce_pos( 5 ), cache_op=Streaming )
}

The RAW sync issue happens between T11 and T3.

T11_l[ bblockIdx.x131{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y134{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS135{1}, bS130{4}, bthreadIdx.x132{blockDim.x} ]
T3_l[ iblockIdx.y151{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS152{1} ] ca_pos( 2 )

While it appears that they use blockIdx.y in the same way, the situation is actually different because T3 is inlined into its consumers, so the actual extent of domain 151 is not its printed extent. Its actual extent is promoted to iblockIdx.y70{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ), 1) )}, and that's why the RAW sync analysis says there's a RAW dependency w.r.t. blockIdx.y.

Here's the generated kernel when the assertion is commented out:


__global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<float, 4, 4> T0, Tensor<float, 2, 2> T1, Tensor<float, 4, 4> T2, Tensor<float, 4, 4> T8) {
  NVFUSER_DEFINE_MAGIC_ZERO;
  float T11[1LL];
  T11[0LL] = 0LL;
  if ((((((nvfuser_index_t)blockIdx.y) >= 0LL) && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((1LL * T0.logical_size[0LL]), 1LL)))) && (((nvfuser_index_t)blockIdx.y) < (1LL * T0.logical_size[0LL])))) {
    T11[0LL]
       = T2[(((nvfuser_index_t)blockIdx.y) * T2.alloc_stride[0LL])];
  } else {
    if ((((((nvfuser_index_t)blockIdx.y) >= 0LL) && (((nvfuser_index_t)blockIdx.y) < (1LL * T0.logical_size[0LL]))) && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((1LL * T0.logical_size[0LL]), 1LL))))) {
      T11[0LL]
         = T2[(((nvfuser_index_t)blockIdx.y) * T2.alloc_stride[0LL])];
    }
  }
  if ((((((((nvfuser_index_t)blockIdx.y) >= 0LL) && ((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) >= 0LL)) && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((T0.logical_size[0LL] * T0.logical_size[1LL]), 1LL)))) && (((nvfuser_index_t)blockIdx.y) < (T0.logical_size[0LL] * T0.logical_size[1LL]))) && (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + 3LL) < (T0.logical_size[2LL] * T0.logical_size[3LL])))) {
    float T10[1LL];
    T10[0LL] = 0LL;
    T10[0LL]
       = T1[(((((nvfuser_index_t)blockIdx.y) / T0.logical_size[1LL]) * T1.alloc_stride[0LL]) + ((((nvfuser_index_t)blockIdx.y) % T0.logical_size[1LL]) * T1.alloc_stride[1LL]))];
    float T5[1LL];
    T5[0LL]
       = T10[0LL];
    float T9[4LL];
    #pragma unroll
    for(nvfuser_index_t i0 = 0; i0 < 4LL; ++i0) {
      T9[i0] = 0LL;
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    #pragma unroll
    for(nvfuser_index_t i0 = 0; i0 < 4LL; ++i0) {
      T9[i0]
         = T0[(((((((nvfuser_index_t)blockIdx.y) / T0.logical_size[1LL]) * T0.alloc_stride[0LL]) + ((((nvfuser_index_t)blockIdx.y) % T0.logical_size[1LL]) * T0.alloc_stride[1LL])) + (T0.alloc_stride[2LL] * (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + (i0 + nvfuser_zero)) / T0.logical_size[3LL]))) + (T0.alloc_stride[3LL] * (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + (i0 + nvfuser_zero)) % T0.logical_size[3LL])))];
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    float T3[1LL];
    T3[0LL]
       = T11[0LL];
    float T4[1LL];
    T4[0LL]
       = T3[0LL];
    Array<float, 4LL, 4> T12;
    #pragma unroll
    for(nvfuser_index_t i1 = 0; i1 < 4LL; ++i1) {
      float T6[1LL];
      T6[0LL]
        = T9[i1]
        * T5[0LL];
      float T7[1LL];
      T7[0LL]
        = T6[0LL]
        * T4[0LL];
      T12[i1]
        = T11[0LL]
        * T7[0LL];
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T8[((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + ((((nvfuser_index_t)blockIdx.y) * T0.logical_size[3LL]) * T0.logical_size[2LL]))], &T12[0LL]);
  } else {
    float T10[1LL];
    T10[0LL] = 0LL;
    if ((((((nvfuser_index_t)blockIdx.y) >= 0LL) && (((nvfuser_index_t)blockIdx.y) < (T0.logical_size[0LL] * T0.logical_size[1LL]))) && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((T0.logical_size[0LL] * T0.logical_size[1LL]), 1LL))))) {
      T10[0LL]
         = T1[(((((nvfuser_index_t)blockIdx.y) / T0.logical_size[1LL]) * T1.alloc_stride[0LL]) + ((((nvfuser_index_t)blockIdx.y) % T0.logical_size[1LL]) * T1.alloc_stride[1LL]))];
    }
    float T5[1LL];
    T5[0LL]
       = T10[0LL];
    float T9[4LL];
    #pragma unroll
    for(nvfuser_index_t i0 = 0; i0 < 4LL; ++i0) {
      T9[i0] = 0LL;
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    #pragma unroll
    for(nvfuser_index_t i0 = 0; i0 < 4LL; ++i0) {
      if ((((((((nvfuser_index_t)blockIdx.y) >= 0LL) && (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + 3LL) >= 0LL)) && (((nvfuser_index_t)blockIdx.y) < (T0.logical_size[0LL] * T0.logical_size[1LL]))) && (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + 3LL) < (T0.logical_size[2LL] * T0.logical_size[3LL]))) && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((T0.logical_size[0LL] * T0.logical_size[1LL]), 1LL))))) {
        T9[i0]
           = T0[(((((((nvfuser_index_t)blockIdx.y) / T0.logical_size[1LL]) * T0.alloc_stride[0LL]) + ((((nvfuser_index_t)blockIdx.y) % T0.logical_size[1LL]) * T0.alloc_stride[1LL])) + (T0.alloc_stride[2LL] * (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + (i0 + nvfuser_zero)) / T0.logical_size[3LL]))) + (T0.alloc_stride[3LL] * (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + (i0 + nvfuser_zero)) % T0.logical_size[3LL])))];
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    float T3[1LL];
    T3[0LL]
       = T11[0LL];
    float T4[1LL];
    T4[0LL]
       = T3[0LL];
    Array<float, 4LL, 4> T12;
    #pragma unroll
    for(nvfuser_index_t i1 = 0; i1 < 4LL; ++i1) {
      float T6[1LL];
      T6[0LL]
        = T9[i1]
        * T5[0LL];
      float T7[1LL];
      T7[0LL]
        = T6[0LL]
        * T4[0LL];
      T12[i1]
        = T11[0LL]
        * T7[0LL];
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    if ((((((((nvfuser_index_t)blockIdx.y) >= 0LL) && ((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) >= 0LL)) && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((T0.logical_size[0LL] * T0.logical_size[1LL]), 1LL)))) && (((nvfuser_index_t)blockIdx.y) < (T0.logical_size[0LL] * T0.logical_size[1LL]))) && (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + 3LL) < (T0.logical_size[2LL] * T0.logical_size[3LL])))) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T8[((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + ((((nvfuser_index_t)blockIdx.y) * T0.logical_size[3LL]) * T0.logical_size[2LL]))], &T12[0LL]);
    }
  }
}

Notice that the predicates of T11 and T3 are different, and T3 may use an invalid T11 value.

So, I think the sync analysis is actually correct here and the schedule is invalid.
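(To make the hazard concrete, here is a small sketch with the repro's sizes plugged in; size0 and size1 stand for T0.logical_size[0] and T0.logical_size[1]. T11 is only written when blockIdx.y < size0, but the branch containing T3 runs whenever blockIdx.y < size0 * size1:)

size0, size1 = 2, 4  # T0.logical_size[0], T0.logical_size[1] in the repro

for block_idx_y in range(size0 * size1):
    writes_t11 = block_idx_y < size0           # predicate guarding the T11 write
    reads_t11 = block_idx_y < size0 * size1    # predicate guarding T3 = T11
    if reads_t11 and not writes_t11:
        print(f"blockIdx.y={block_idx_y}: T3 reads an invalid T11 value")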

jacobhinkle commented 8 months ago

Side note: I wonder why the T11 predicate is so strange in the generated kernel:

  float T11[1LL];
  T11[0LL] = 0LL;
  if ((((((nvfuser_index_t)blockIdx.y) >= 0LL) 
      && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((1LL * T0.logical_size[0LL]), 1LL))))
      && (((nvfuser_index_t)blockIdx.y) < (1LL * T0.logical_size[0LL])))) {
    T11[0LL]
       = T2[(((nvfuser_index_t)blockIdx.y) * T2.alloc_stride[0LL])];
  } else {
    if ((((((nvfuser_index_t)blockIdx.y) >= 0LL)
        && (((nvfuser_index_t)blockIdx.y) < (1LL * T0.logical_size[0LL])))
        && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((1LL * T0.logical_size[0LL]), 1LL))))) {
      T11[0LL]
         = T2[(((nvfuser_index_t)blockIdx.y) * T2.alloc_stride[0LL])];
    }
  }
  if ((((((((nvfuser_index_t)blockIdx.y) >= 0LL)
      && ((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) >= 0LL))
      && (((nvfuser_index_t)blockIdx.y) < (ceilDiv((T0.logical_size[0LL] * T0.logical_size[1LL]), 1LL))))
      && (((nvfuser_index_t)blockIdx.y) < (T0.logical_size[0LL] * T0.logical_size[1LL])))
      && (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + 3LL) < (T0.logical_size[2LL] * T0.logical_size[3LL])))) {
    float T3[1LL];
    T3[0LL]
       = T11[0LL];

The T3 predicate could also be simplified using BID >= 0, TID >= 0, BDIM >= 0, and ceilDiv(x, 1) == x:

  float T11[1LL];
  T11[0LL] = 0LL;
  if ((nvfuser_index_t)blockIdx.y < T0.logical_size[0LL]) {
    T11[0LL]
       = T2[(((nvfuser_index_t)blockIdx.y) * T2.alloc_stride[0LL])];
  }
  if ((((nvfuser_index_t)blockIdx.y) < (T0.logical_size[0LL] * T0.logical_size[1LL]))
      && (((((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.x)) + ((nvfuser_index_t)threadIdx.x)) * 4LL) + 3LL) < (T0.logical_size[2LL] * T0.logical_size[3LL]))) {
    float T3[1LL];
    T3[0LL]
       = T11[0LL];
naoyam commented 8 months ago

It's just because I disabled index simplification and hoisting, since those passes make it a little cumbersome to see which predicates are actually used.

naoyam commented 8 months ago

I looked into this more deeply, and it turned out this is in fact an interesting inlining problem. The problem starts with the missing inlining of T11:

Producer: T11_l[ bblockIdx.x131{( ceilDiv(( ceilDiv(( 1 * 1 ), 4) ), blockDim.x) )}, iblockIdx.y134{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS135{1}, bS130{4}, bthreadIdx.x132{blockDim.x} ]
Consumer: T3_l[ iblockIdx.y151{( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * 1 ), 1) )}, iUS152{1} ] ca_pos( 2 )

This is because of ID 131 of T11, for which there is no corresponding domain in T3. It's just a broadcast domain, so what usually happens is that we move such a domain to an inner position so that ID 134 ends up at the outermost position. However, in this case there's another consumer of T11, T12, which does have all the matching domains, so the pointwise scheduler does not reorder them. In fact, it should not reorder them, for performance reasons.

Since the inlineMost logic only looks at the immediate consumers of T11, i.e., T3 and T12, and takes the minimum CA position possible, it determines that no inlining is possible because of T3. However, that isn't actually the case, since T3 is itself inlined into T4, which does have all the matching domains. It should be legal for T11 to have a CA position of 2 or larger.

More specifically, if we only look at T3, it appears that there are loops only for the two leaf domains of T3. However, since T3 is inlined into T4, which has additional domains outside the positions of those matching T3, the actual loop nest does have a loop for each of the five leaf domains of T4, including those that match the T11 domains.

I think what this all means is that MaxPosCalculator needs an extension: instead of just looking at producer-consumer pairs, it would need to understand the actual loop nest into which a given producer would be inlined. Note that computing the max position has a consumer-to-producer dependency, since the inlining position of a tensor depends on the inlining position of each of its consumers.
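(A toy sketch of that idea, with illustrative names rather than nvFuser's MaxPosCalculator API: the pairwise minimum over immediate consumers yields 0 because of T3, but accounting for T3's own inlining into T4 allows position 2.)

# Pairwise matching positions of T11 against its immediate consumers:
# T3 has no domain matching T11's outermost ID 131, so pairwise it allows 0;
# T12 has all the matching domains, so it allows 5.
pairwise_pos = {"T3": 0, "T12": 5}

naive = min(pairwise_pos.values())  # current inlineMost-style answer: 0

# Loop-nest-aware idea: T3 itself is inlined into T4 at position 2, so the
# loops surrounding T3 are really T4's, and T11 could be inlined that deep.
own_ca_pos = {"T3": 2}
effective = min(max(p, own_ca_pos.get(c, p)) for c, p in pairwise_pos.items())

print(naive, effective)  # 0 2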

(I'll update this comment with some illustration later).