Closed naoyam closed 2 years ago
Actually, it's not just reduction. This also fails validation because the predicate of tv2
is eliminated:
TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeConcreteTensor({2, 3});
  fusion.addInput(tv0);

  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
  auto tv2 = add(tv1, IrBuilder::create<Double>(1));
  auto tv3 = add(tv2, IrBuilder::create<Double>(1));
  fusion.addOutput(tv3);

  tv2->split(1, 2);
  tv1->split(1, 2);

  fusion.printMath();
  fusion.printKernel();

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({2, 3}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {t0});
  auto cg_outputs = fe.runFusion({t0});

  auto ref = t0 + 3;

  testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
}
It seems this has been broken ever since we changed how the index is linearized.
In your first example it's the T1 = T0 one that needs to be predicated, right?
Had to revisit; I think I understand the issue now.
Okay, wow, this took me way too much time to understand. In the second example (since it's simpler) it's the combination of:
auto& T2 = T1;
#pragma unroll
for(nvfuser_index_t i25 = 0; i25 < 2; ++i25) {
  #pragma unroll
  for(nvfuser_index_t i26 = 0; i26 < (ceilDiv(3, 2)); ++i26) {
    #pragma unroll
    for(nvfuser_index_t i27 = 0; i27 < 2; ++i27) {
      T2[(i25 * 3) + ((i26 * 2) + i27)]
          = T1[(i25 * 3) + ((i26 * 2) + i27)]
          + (float) 1;
    }
  }
}
and we hit overlapping indices at [0, 1, 1] and [1, 0, 0] (both map to linear index 3), so:
T2[(i25 * 3) + ((i26 * 2) + i27)]
    = T1[(i25 * 3) + ((i26 * 2) + i27)]
    + (float) 1;
is run twice on that point, which wouldn't be an issue, except the memory is aliased, so it's effectively running an in-place computation twice.
I think there are three options here, and you decided to go with the first option in https://github.com/csarofeen/pytorch/pull/1582:
I think doing option 1 in the short term is a reasonable approach to quickly patch the issue, though I don't know all the trade-offs here. Option 2 was really annoying because of outer splits, which meant we would have needed even more predicates in those cases, so maybe it's a bad idea. Option 3 could be good, but then register pressure may increase.
@csarofeen Yes, #1582 should disable eliminating required predicates. It's more conservative, but the impact would be minimal in practice. The problem only happens when an IterDomain is split but not parallelized. If parallelized, the domain becomes "zero merged in", so indexing computes the extent by multiplying the extents of child domains. After all, we haven't seen any actual error since the indexing was changed.
This logic seems wrong (blaming myself):
https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/lower_predicate.cpp#L343-L348