When we create a Broadcast IterDomain using tv->broadcast(pos), the new ID is inserted into the loop domain of tv and also into a separate domain, tv->domain()->additionalIDs(). The additional IDs domain is used to find traversal paths so that all ID expressions can be defined when building ID graphs. These domains must be preserved whenever we clone or mutate a TensorDomain; otherwise, building an ID graph with the new Fusion can fail.
We currently do not preserve the additional IDs when mutating, as seen in this repro:
TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion* fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion->addInput(tv0);

  auto tv1 = exp(tv0);
  fusion->addOutput(tv1);

  // We add a broadcast domain bS2{1}. This adds the new Broadcast ID to
  // tv1->domain()->additionalIDs()
  //   logical: [ iS1{i0} ]
  //   loop: [ iS1{i0}, bS2{1} ]
  //   additional IDs: [ bS2{1} ]
  tv1->broadcast(1);
  EXPECT_FALSE(tv1->domain()->additionalIDs().empty());

  // After this split we have
  //   logical: [ iS1{i0} ]
  //   loop: [ iS1{i0}, bS3{1}, bS4{2} ]
  //   additional IDs: [ bS2{1} ]
  tv1->split(1, 2);
  EXPECT_FALSE(tv1->domain()->additionalIDs().empty());

  // Now register a mutation that will alter some IDs in the domain
  OptOutMutator mut;
  mut.registerMutation(
      tv1->axis(0)->extent(), IrBuilder::create<Val>(DataType::Index));

  TensorDomain* old_tensor_domain = tv1->domain();
  auto all_stmts = StmtSort::getStmts(
      fusion,
      /*traverse_members*/ true,
      /*traverse_attributes*/ true,
      /*traverse_siblings*/ true);
  for (auto stmt : all_stmts) {
    mut.dispatchMutate(stmt);
  }

  EXPECT_TRUE(tv1->domain() != old_tensor_domain)
      << "Mutation did not change the TensorDomain";
  EXPECT_FALSE(tv1->domain()->additionalIDs().empty())
      << "Mutation did not preserve additional IDs";
}
I noticed this in a very similar case while working on #3406.
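For readers unfamiliar with the failure mode, here is a minimal standalone sketch of how a mutator that rebuilds a domain field by field can silently drop a side list. It does not use nvFuser: ToyDomain, mutateDroppingAdditional, and mutatePreservingAdditional are hypothetical names invented for illustration, IDs are modeled as plain strings, and the sketch makes no claim about how OptOutMutator is actually implemented.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a tensor domain (hypothetical, NOT the nvFuser class).
struct ToyDomain {
  std::vector<std::string> logical;
  std::vector<std::string> loop;
  std::vector<std::string> additional_ids;
};

// Replace every occurrence of `from` with `to` in a list of ID names.
std::vector<std::string> mutateAll(
    const std::vector<std::string>& ids,
    const std::string& from,
    const std::string& to) {
  std::vector<std::string> out;
  out.reserve(ids.size());
  for (const auto& id : ids) {
    out.push_back(id == from ? to : id);
  }
  return out;
}

// Rebuilds the domain but forgets the side list: the kind of omission
// the issue describes, reproduced here on the toy type only.
ToyDomain mutateDroppingAdditional(
    const ToyDomain& d,
    const std::string& from,
    const std::string& to) {
  ToyDomain out;
  out.logical = mutateAll(d.logical, from, to);
  out.loop = mutateAll(d.loop, from, to);
  return out;  // out.additional_ids is left empty
}

// Same rebuild, but the side list is carried over (and mutated) as well.
ToyDomain mutatePreservingAdditional(
    const ToyDomain& d,
    const std::string& from,
    const std::string& to) {
  ToyDomain out = mutateDroppingAdditional(d, from, to);
  out.additional_ids = mutateAll(d.additional_ids, from, to);
  return out;
}

// Builds a toy domain shaped like the repro after broadcast + split.
ToyDomain makeToyDomain() {
  ToyDomain d;
  d.logical.push_back("iS1");
  d.loop.push_back("iS1");
  d.loop.push_back("bS3");
  d.loop.push_back("bS4");
  d.additional_ids.push_back("bS2");
  return d;
}

// Self-check that mirrors the two expectations at the end of the repro.
void checkToyMutation() {
  const ToyDomain d = makeToyDomain();
  assert(mutateDroppingAdditional(d, "iS1", "iS5").additional_ids.empty());
  assert(!mutatePreservingAdditional(d, "iS1", "iS5").additional_ids.empty());
}
```

Calling checkToyMutation() from a main function exercises both variants. In the toy model the fix is simply to treat the side list like every other member when a new object is constructed, which is the property the repro's final EXPECT_FALSE checks for the real TensorDomain.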