jjsjann123 opened 7 months ago
Is this a real error or just a validation error? In other words, should we avoid mutating the allocation domain of the output of the matmul test?
I think it's just a validation error in @Priya2698 's example in #1775:
```cpp
auto tv0 = makeConcreteTensor(a_shape, DataType::Half);
auto tv1 = makeConcreteTensor(b_shape, DataType::Half);
auto tv0b = broadcast(tv0, {false, false, true}); // [M, K, 1]
auto tv1b = broadcast(tv1, {true, false, false}); // [1, K, N]
auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
auto tv3 = makeConcreteTensor({m}, DataType::Half);
auto tv4 = castOp(DataType::Float, tv3);
auto tv5 = biasEpilogue(tv2, tv4);
auto tv6 = castOp(DataType::Half, tv5);
fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv3);
fusion->addOutput(tv6);
```
So the chain of propagation follows tv3 -> tv4 -> tv5 -> tv6. After updating the allocation order of the output here, we hit the assert.
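To make the failure mode concrete, here is a toy Python model (an illustration only, not the nvFuser implementation) of how an inferred allocation order, represented as a permutation of axes, flows producer-to-consumer along the chain, and how overwriting an already-mutated output trips an assert:

```python
class TV:
    """A tensor view with an optional inferred allocation order."""
    def __init__(self, name, alloc_order=None):
        self.name = name
        self.alloc_order = alloc_order  # None = not yet inferred

def propagate(producer, consumer):
    """Copy the producer's allocation order onto the consumer.

    Mirrors the failure mode in the issue: if the consumer's allocation
    domain was already mutated (e.g. the fusion output), overwriting it
    with a conflicting order trips an assert.
    """
    if producer.alloc_order is None:
        return  # nothing to propagate
    assert (consumer.alloc_order is None
            or consumer.alloc_order == producer.alloc_order), (
        f"conflicting allocation order on {consumer.name}")
    consumer.alloc_order = producer.alloc_order

# Benign chain: tv3's order flows through tv4 and tv5 unopposed.
tv3, tv4, tv5 = TV("tv3", alloc_order=(0,)), TV("tv4"), TV("tv5")
propagate(tv3, tv4)
propagate(tv4, tv5)

# The output was already mutated by the pass to a different order, so
# the same propagation step now hits the assert.
out = TV("tv6", alloc_order=(1, 0))
src = TV("tv5", alloc_order=(0, 1))
try:
    propagate(src, out)
    conflict = False
except AssertionError:
    conflict = True
```

The names and the permutation encoding are assumptions for illustration; the point is only that blind propagation into an already-set output allocation domain is where the assert fires.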
Two things I think worth a discussion:
- Should we still be able to specify the output allocation order? @Priya2698, if "someone" specified the output allocation, should we still respect it? We probably shouldn't assert, so I think this is a legitimate issue we should fix in validation.
Do you mean if the user specifies an output allocation order?
Another question: what happens when some operator in the fusion is not handled yet? You mentioned that fusedMultiplySum is not supported in the propagation rules. We are hitting this assert since we do not have any allocation order propagated for tv2, and end up updating the allocation order based on tv3.
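A toy sketch (an assumption about the pass's behavior, not nvFuser code) of the fallback being described: because fusedMultiplySum has no propagation rule, tv2 carries no inferred order, so a binary epilogue op falls back to its other input, the 1-D bias tv3:

```python
def resolve_binary(lhs_order, rhs_order):
    """Pick an allocation order for a binary op's output.

    Prefers the first input that actually has an inferred order;
    returns None when neither does.
    """
    for order in (lhs_order, rhs_order):
        if order is not None:
            return order
    return None

tv2_order = None   # fusedMultiplySum: unhandled, nothing inferred
tv3_order = (0,)   # 1-D bias, order known from the fusion input
tv5_order = resolve_binary(tv2_order, tv3_order)
```

Under this model tv5 (and then tv6) ends up ordered by the low-rank bias rather than by the matmul result, which is how the output's allocation domain gets mutated in the first place.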
> Do you mean if the user specifies an output allocation order?
Yes. If the user specifies an output allocation order, we need to respect that. Admittedly, this is low priority for the matmul case right now.
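One possible shape for that fix, sketched as a hypothetical guard (this function does not exist in nvFuser; it is only an illustration of the policy discussed here): a user-specified output allocation order takes precedence over whatever the inference pass derived.

```python
def update_output_order(current, inferred, user_specified):
    """Return the allocation order the output should end up with.

    If the user explicitly specified the output's allocation order,
    respect it; otherwise adopt the pass's inferred order when one
    exists.
    """
    if user_specified and current is not None:
        return current  # respect the user's choice, never assert
    return inferred if inferred is not None else current

# User pinned the output to (1, 0); the pass inferred (0, 1) from the
# bias chain, but the user's order wins.
result = update_output_order(current=(1, 0), inferred=(0, 1),
                             user_specified=True)
```

The key design choice is that the pass degrades to a no-op on user-specified outputs instead of asserting, which matches the "we probably shouldn't assert" point above.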
Regarding the allocation domain propagation, should it be considered the responsibility of the pass to figure out the right output allocation for the operation?
I think the propagation rule should try to help. But if fusedMultiplySum requires its output to be in a certain allocation order, that needs to be enforced and represented before hitting the optimization passes. I.e., in this example, it's more than just the fusedMultiplySum output, but its epilogue fusion as well.
This is going to be tricky for the pass to handle. (not that it's easier anywhere else :) )
The allocation order inference pass has mutated the allocation domain on output tensors in matmul tests. I accidentally missed the broken CI since the two propagation rules, which went in in parallel, weren't combined when running tests. :cry:
I'm temporarily disabling the optimization pass in matmul tests for the sake of a green CI. Let's discuss how we should avoid these issues (hint: maybe we should enforce output stride order in matmul tests).