Open · naoyam opened this issue 2 years ago
`blockBroadcast` cannot support this type of use case. I would expect the parallel validation to fail in this case instead of the predicate check complaining here. This is a change in the parallelization strategy where T3 has a completely different parallel scheme than T2. The only way this can be supported is when T2 is in shared memory; then we should do the right thing, which we almost do when that check is removed:
```cpp
__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3) {
  alignas(4) extern __shared__ char array[];
  unsigned offset = 0;
  void* shared_mem = array;
  offset += ((blockDim.x * blockDim.y * blockDim.z) * sizeof(float));
  float T1[T0.size[0]];
  #pragma unroll 1
  for(nvfuser_index_t i16 = 0; i16 < T0.size[0]; ++i16) {
    T1[i16] = 0;
  }
  offset = alignBufferSize(offset,4);
  float* T2 = reinterpret_cast<float*>(array + offset);
  offset += ((T0.size[0] * 1) * sizeof(float));
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < T0.size[0]; ++i13) {
    blockReduce<true, false, false>(
      T1[i13],
      T0[(i13 * T0.stride[0]) + (((nvfuser_index_t)threadIdx.x) * T0.stride[1])],
      [](float &a, float b) { a = a + b; },
      threadIdx,
      blockDim,
      static_cast<float*>(shared_mem),
      (((nvfuser_index_t)threadIdx.x) < T0.size[1]),
      float(0));
  }
  if (((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * 1)) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T2[((nvfuser_index_t)threadIdx.x)]
      = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
  }
  if ((((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1])) && (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1]))) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T3[(((nvfuser_index_t)threadIdx.x) * 1)]
      = T2[(((nvfuser_index_t)threadIdx.x) / T0.size[1])]
      + T0[((((nvfuser_index_t)threadIdx.x) / T0.size[1]) * T0.stride[0]) + ((((nvfuser_index_t)threadIdx.x) % T0.size[1]) * T0.stride[1])];
  }
  __barrier_sync(0);
}
```
The issue here is that the `((nvfuser_index_t)threadIdx.x) == 0` conjunct shouldn't exist in the predicates of the T2 and T3 expressions. This bug isn't high priority, as we don't use parallelization schemes like this in practice, but it's not great that it produces such incorrect code.
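Concretely, dropping the spurious conjunct would leave just the extent predicate on the T2 write, something like this (illustrative only, derived from the kernel above):

```cpp
// Illustrative: the T2 write with the spurious `== 0` conjunct removed,
// so every thread covering the merged extent writes its own element.
if (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * 1)) {
  T2[((nvfuser_index_t)threadIdx.x)]
    = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
}
```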
A disabled test case is added for this issue in https://github.com/csarofeen/pytorch/pull/1412.
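For reference, the fusion in question is roughly of the following shape. This is a sketch reconstructed from the generated kernel above using the nvFuser C++ test API; the actual disabled test in the PR may be written and scheduled differently:

```cpp
// Sketch only: reconstructed from the generated kernel, not the actual test.
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(2); // T0
fusion.addInput(tv0);

auto tv1 = sum(tv0, {1});                 // T1: reduce the inner dim
auto tv2 = broadcast(tv1, {false, true}); // T2: broadcast it back
auto tv3 = add(tv2, tv0);                 // T3
fusion.addOutput(tv3);

// T1 reduces across TIDx over the inner dimension.
tv1->axis(1)->parallelize(ParallelType::TIDx);

// T2 and T3 instead flatten their two domains (merging T2's broadcast
// domain with a non-broadcast domain) and bind TIDx to the merged
// extent -- a completely different parallel scheme than T1's.
tv2->merge(0);
tv2->axis(0)->parallelize(ParallelType::TIDx);
tv2->setMemoryType(MemoryType::Shared);

tv3->merge(0);
tv3->axis(0)->parallelize(ParallelType::TIDx);
```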
I think another thing we may need to think about is that once a broadcast domain is merged with a non-broadcast domain, it becomes a non-broadcast domain, i.e., `IterDomain::isBroadcast()` returns false. We often just look at leaf IDs and see if there's any broadcast (e.g., `TensorDomain::hasBroadcast()`), which would just return false even when the root domain has a broadcast.
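For example (illustrative, assuming the scenario above):

```cpp
// tv2's root domain: [ iS{i0}, bS{1} ] -- the second ID is a broadcast.
tv2->merge(0); // leaf domain: [ iS{i0*1} ]

// The merged leaf ID is a regular iteration domain:
//   tv2->axis(0)->isBroadcast()  -> false
// and a leaf-only query misses the root broadcast entirely:
//   tv2->domain()->hasBroadcast() -> false
```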
What's not really clear to me is what the merged domain of a broadcast and a non-broadcast domain would really mean. It doesn't matter when the broadcast is trivial, e.g., when it's on shared memory, not concretized, or not predicated. Otherwise, I'm not sure what the right semantics would be. If we don't care, as that won't be necessary, we might just want to disable merging of broadcast and non-broadcast domains when the broadcast is not trivial.
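If we go that route, the check could be as simple as rejecting the offending merges during validation. A purely hypothetical sketch (the helper `isTrivialBroadcast` is made up here, not an actual nvFuser function):

```cpp
// Hypothetical sketch of the suggested restriction: reject merges that
// combine a non-trivial broadcast ID with a non-broadcast ID.
void validateMerge(Merge* merge) {
  bool outer_bcast = merge->outer()->isBroadcast();
  bool inner_bcast = merge->inner()->isBroadcast();
  if (outer_bcast == inner_bcast) {
    return; // both broadcast or neither: fine
  }
  IterDomain* bcast_id = outer_bcast ? merge->outer() : merge->inner();
  // isTrivialBroadcast: assumed helper -- true when the broadcast is on
  // shared memory, not concretized, or not predicated.
  TORCH_INTERNAL_ASSERT(
      isTrivialBroadcast(bcast_id),
      "Merging a non-trivial broadcast domain with a non-broadcast domain is not supported.");
}
```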
I'm not sure what the correct behavior should be, but this fusion currently fails at the validation at the beginning of lowering.
The issue is not (just) the validation, though. Even if the validation is skipped, invalid code is generated, as shown in the kernel above.
Notice that there's no `blockBroadcast`. This is not due to the concrete broadcast domain PR (#1412). It's because merging broadcast and non-broadcast domains results in a non-broadcast domain, so T2 has no broadcast, and no broadcast runtime call is generated. This problem only happens when a broadcast domain requires an actual parallel broadcast, which means the tensor must be predicated with the same parallel type as is used on the broadcast domain.
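For contrast, with the broadcast recognized, the T2 write would presumably go through the block broadcast runtime rather than a plain thread-0 store, along these lines (the exact signature of the runtime helper is assumed here, mirroring the style of the `blockReduce` call in the kernel above):

```cpp
// Assumed call shape, for illustration only: broadcast T1's value along
// TIDx into T2 instead of storing it from a single thread.
blockBroadcast<true, false, false>(
    T2[((nvfuser_index_t)threadIdx.x)],
    T1[(((nvfuser_index_t)threadIdx.x) / 1)],
    static_cast<float*>(shared_mem),
    (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * 1)));
```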