Thank you for the bug report. cc @jjsjann123
The issue seems to be that the output `T4` is aliasing `T2`, and they are both broadcast. The generated kernel looks like this:
```cpp
__global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T1, Tensor<float, 0, 0> T2, Tensor<float, 0, 0> T11, Tensor<float, 1, 1> T7) {
  nvfuser_index_t i0;
  i0 = 4 * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i1;
  i1 = 512 * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i2;
  i2 = i0 + i1;
  bool b3;
  b3 = ((3 + i0) + i1) < T0.logical_size[0LL];
  float T10[1];
  T10[0] = 0;
  T10[0] = T2[0];
  float T3[1];
  T3[0] = T10[0] + (float) 1.00000000000000000e+00;
  float T12[1];
  T12[0] = T3[0];
  if (((((nvfuser_index_t)blockIdx.x) == 0) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T11[0] = T12[0];
  }
  float T4[1];
  T4[0] = T3[0];
  float T5[1];
  T5[0] = T4[0];
  if ((((i0 + 3) + i1) < T0.logical_size[0LL])) {
    Array<float, 4, 4> T9;
    T9.set(float(0));
    loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T9[0], &T1[i2]);
    Array<float, 4, 4> T8;
    T8.set(float(0));
    loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T8[0], &T0[i2]);
    // Alias Allocation - register
    auto& T13 = T9;
    #pragma unroll
    for(nvfuser_index_t i4 = 0; i4 < 4; ++i4) {
      float T6[1];
      T6[0] = T9[i4] * T5[0];
      T13[i4] = T8[i4] + T6[0];
    }
    loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T7[i2], &T13[0]);
  } else {
    Array<float, 4, 4> T9;
    T9.set(float(0));
    if (b3) {
      loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T9[0], &T1[i2]);
    }
    Array<float, 4, 4> T8;
    T8.set(float(0));
    if (b3) {
      loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T8[0], &T0[i2]);
    }
    // Alias Allocation - register
    auto& T13 = T9;
    #pragma unroll
    for(nvfuser_index_t i4 = 0; i4 < 4; ++i4) {
      float T6[1];
      T6[0] = T9[i4] * T5[0];
      T13[i4] = T8[i4] + T6[0];
    }
    if (b3) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T7[i2], &T13[0]);
    }
  }
}
```
I think that to properly compute something like this we would need a couple of semaphores, and we would have to write the value back to the aliased input only in the last block, after all other threads have read the original value. Until then, should we throw an error if asked to alias broadcast inputs?
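For illustration, here is a minimal hand-written sketch of that last-block pattern in plain CUDA (a hypothetical kernel, not nvFuser codegen; the `blocks_done` counter is assumed to be zeroed before each launch):

```cpp
// Grid-wide arrival counter; must be reset to 0 before each launch.
__device__ unsigned int blocks_done = 0;

// `aliased_out` points at the same buffer as `in`, so the write-back has to
// wait until every block has finished reading the original value.
__global__ void write_alias_in_last_block(const float* in, float* aliased_out) {
  float v = in[0] + 1.0f;  // every thread reads the original value

  __syncthreads();          // all reads in this block are done
  __shared__ bool is_last;
  if (threadIdx.x == 0) {
    __threadfence();        // order this block's accesses before the arrival
    // The block that bumps the counter to gridDim.x arrives last.
    is_last = (atomicAdd(&blocks_done, 1u) == gridDim.x - 1);
  }
  __syncthreads();
  if (is_last && threadIdx.x == 0) {
    aliased_out[0] = v;     // safe: no block reads in[0] after this point
  }
}
```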
Good spot! @jacobhinkle, we can reject `copy_` in thunder if we see a size(1) tensor, I guess?
So the logic is: if there is no broadcast on the aliased input/output, no two threads map to the same element of the aliased buffer, hence there is no potential race in the RW sequence.
```cpp
float T10[1];
T10[0] = 0;
T10[0] = T2[0];
float T3[1];
T3[0] = T10[0] + (float) 1.00000000000000000e+00;
float T12[1];
T12[0] = T3[0];
if (((((nvfuser_index_t)blockIdx.x) == 0) && (((nvfuser_index_t)threadIdx.x) == 0))) {
  T11[0] = T12[0];
}
```
How do we define broadcast here? That doesn't feel like a local decision we can make in integration, but rather a check that should run after segmentation. I.e., looking at the batch_norm example here: https://github.com/Lightning-AI/lightning-thunder/blob/2a7602d69836a5c5671e2a898dcb17aedcb3adba/thunder/torch/__init__.py#L3536-L3550, `mean` is being broadcast and used to compute `out`. But on the branch that updates `prims.copy_(new_running_mean, running_mean)` there's no broadcast. So we need to know what the segmentation looks like in order to assert/throw?
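For concreteness, here is a rough sketch of what such a post-segmentation check could look like; `Expr` and `TensorView` below are hypothetical stand-in types, not nvFuser's real classes:

```cpp
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for the fusion IR; real nvFuser types differ.
struct TensorView;
struct Expr {
  bool is_broadcast;
  std::vector<TensorView*> outputs;
};
struct TensorView {
  std::vector<Expr*> uses;
};

// Returns true if any use-def path from the aliased input goes through a
// broadcast, i.e. multiple threads would read the buffer that the segment
// also writes in place.
bool aliasIsReadThroughBroadcast(TensorView* aliased_input) {
  std::unordered_set<TensorView*> visited{aliased_input};
  std::vector<TensorView*> stack{aliased_input};
  while (!stack.empty()) {
    TensorView* tv = stack.back();
    stack.pop_back();
    for (Expr* use : tv->uses) {
      if (use->is_broadcast) return true;
      for (TensorView* out : use->outputs) {
        if (visited.insert(out).second) stack.push_back(out);
      }
    }
  }
  return false;
}
```

Run per segment, a check like this would only flag an alias when a broadcast consumer of the aliased buffer lands in the same segment as the in-place write, which is why the answer depends on how the fusion is segmented.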
@jacobhinkle, can you repost with index hoisting disabled?
Also, can you point to the problematic area in the kernel? I'm having trouble seeing how this is being parallelized, which makes it difficult for me to understand the problem.
Could you also reduce the repro; I think you can simply remove T9 and output T8 and still reproduce.
Reduced repro with no `index_hoist`:
```cpp
__global__ void nvfuser_pointwise_f0_c1_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T1, Tensor<float, 0, 0> T2, Tensor<float, 0, 0> T9, Tensor<float, 1, 1> T6) {
  float T8[1];
  T8[0] = 0;
  T8[0] = T2[0];
  float T3[1];
  T3[0] = T8[0] + (float) 1.00000000000000000e+00;
  float T10[1];
  T10[0] = T3[0];
  if (((((nvfuser_index_t)blockIdx.x) == 0) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T9[0] = T10[0];
  }
  float T4[1];
  T4[0] = T3[0];
  float T5[1];
  T5[0] = T4[0];
  if (((((4 * ((nvfuser_index_t)threadIdx.x)) + 3) + (512 * ((nvfuser_index_t)blockIdx.x))) < T1.logical_size[0LL])) {
    Array<float, 4, 4> T7;
    T7.set(float(0));
    loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T7[0], &T1[((4 * ((nvfuser_index_t)threadIdx.x)) + (512 * ((nvfuser_index_t)blockIdx.x)))]);
    Array<float, 4, 4> T11;
    #pragma unroll
    for(nvfuser_index_t i0 = 0; i0 < 4; ++i0) {
      T11[i0] = T7[i0] * T5[0];
    }
    loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T6[((4 * ((nvfuser_index_t)threadIdx.x)) + (512 * ((nvfuser_index_t)blockIdx.x)))], &T11[0]);
  } else {
    Array<float, 4, 4> T7;
    T7.set(float(0));
    if ((((4 * ((nvfuser_index_t)threadIdx.x)) + (512 * ((nvfuser_index_t)blockIdx.x))) < T1.logical_size[0LL])) {
      loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T7[0], &T1[((4 * ((nvfuser_index_t)threadIdx.x)) + (512 * ((nvfuser_index_t)blockIdx.x)))]);
    }
    Array<float, 4, 4> T11;
    #pragma unroll
    for(nvfuser_index_t i0 = 0; i0 < 4; ++i0) {
      T11[i0] = T7[i0] * T5[0];
    }
    if ((((4 * ((nvfuser_index_t)threadIdx.x)) + (512 * ((nvfuser_index_t)blockIdx.x))) < T1.logical_size[0LL])) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T6[((4 * ((nvfuser_index_t)threadIdx.x)) + (512 * ((nvfuser_index_t)blockIdx.x)))], &T11[0]);
    }
  }
}
```
```
======================================
Arguments for kernel0:
Inputs:
  Tensor(sizes=[4194304], stride=[1], dtype=float, device=cuda:0, data_ptr=0x72c4f4600000)
  Tensor(sizes=[4194304], stride=[1], dtype=float, device=cuda:0, data_ptr=0x72c4f8a00000)
  Tensor(sizes=[], stride=[], dtype=float, device=cuda:0, data_ptr=0x72c4f5800000)
Outputs:
  Float [] (strides = [], address = 0x72c4f5800000)
  Float [4194304] (strides = [1], address = 0x72c4e6000000)
Intermediate global buffers:
```
This does reproduce the error. Note the portion:
```cpp
if (((((nvfuser_index_t)blockIdx.x) == 0) && (((nvfuser_index_t)threadIdx.x) == 0))) {
  T9[0] = T10[0];
}
```
Also note in the `kernel_args` that the `T9` pointer is equal to that of `T2`, so the load of `T8` is in a race with this write to `T9`.
For reference, here is the scheduled IR:
```
%kernel {
  T8_l[ ] = Set( T2_g[ ], cache_op=AllLevels )
  T3_l[ ] = T8_l[ ] + double(1);
  T10_l[ ] = Set( T3_l[ ], cache_op=Streaming )
  T9_g[ ] = Set( T10_l[ ], cache_op=Streaming )
  T7_l[ iblockIdx.x37{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), 1) ), 128) )}, iUS36{1}, iV34{4}, ithreadIdx.x38{128} ] ca_pos( 2 )
    = Set( T1_g[ iS43{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), 1) ), 128) )}, iS42{1}, iS40{4}, iS44{128} ], cache_op=Streaming )
  T4_l[ bblockIdx.x31{( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), 1) ), 128) )}, bUS30{1}, bS28{4}, bthreadIdx.x32{128} ]
    = broadcast( T3_l[ ] )
  T5_l[ bblockIdx.x25{( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), 1) ), 128) ) ex ( ceilDiv(( ceilDiv(( ceilDiv(4194304, 4) ), 1) ), 128) )}, bUS24{1}, bS22{4}, bthreadIdx.x26{128} ]
    = expand( T4_l[ bblockIdx.x31{( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), 1) ), 128) )}, bUS30{1}, bS28{4}, bthreadIdx.x32{128} ], {4194304} )
  T11_l[ iblockIdx.x19{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), 1) ), 128) )}, iUS18{1}, iS16{4}, ithreadIdx.x20{128} ] ca_pos( 2 ) produce_pos( 2 )
    = T7_l[ iblockIdx.x37{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), 1) ), 128) )}, iUS36{1}, iV34{4}, ithreadIdx.x38{128} ] ca_pos( 2 )
    * T5_l[ bblockIdx.x25{( ceilDiv(( ceilDiv(( ceilDiv(1, 4) ), 1) ), 128) ) ex ( ceilDiv(( ceilDiv(( ceilDiv(4194304, 4) ), 1) ), 128) )}, bUS24{1}, bS22{4}, bthreadIdx.x26{128} ];
  T6_g[ iblockIdx.x13{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), 1) ), 128) )}, iUS12{1}, iV10{4}, ithreadIdx.x14{128} ] ca_pos( 2 ) produce_pos( 2 )
    = Set( T11_l[ iblockIdx.x19{( ceilDiv(( ceilDiv(( ceilDiv(i2, 4) ), 1) ), 128) )}, iUS18{1}, iS16{4}, ithreadIdx.x20{128} ] ca_pos( 2 ) produce_pos( 2 ), cache_op=Streaming )
}
```
Mildly interesting: `compute-sanitizer --tool=racecheck` does not find any hazards in this repro. (Racecheck only reports shared-memory hazards, so a cross-block race through global memory is outside its scope.)
Self-assigned. Note to self: I should also take a look at the thunder issue linked above.
Resolved by PR #2999
The following script gave numerically incorrect values on a small part of the array. I generated its fusion definition by calling `FusionDefinition.getReproErrorString` manually.