NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Additional IDs are lost when mutating a TensorDomain #3409

Closed jacobhinkle closed 6 days ago

jacobhinkle commented 6 days ago

When we create a Broadcast IterDomain using tv->broadcast(pos), it inserts the new ID into the loop domain of tv, as well as into another domain called tv->additionalIDs(). This additional IDs domain is used to find traversal paths so that all ID expressions can be defined when building ID graphs. We must ensure that these domains are preserved whenever we clone or mutate a TensorDomain; otherwise, building an ID graph with the new Fusion can run into issues.
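To illustrate the failure mode outside of nvFuser, here is a minimal toy sketch. Everything in it is hypothetical: ToyDomain, renameIds, mutateLossy, and mutatePreserving are made-up names, IDs are modeled as plain strings, and none of this is the nvFuser API or the actual OptOutMutator code. It only shows how rebuilding a domain from its logical and loop lists silently drops a third list unless that list is carried over explicitly.

```cpp
#include <string>
#include <vector>

// Toy stand-in for a tensor domain. This is NOT nvFuser's TensorDomain;
// IDs are modeled as plain strings purely for illustration.
struct ToyDomain {
  std::vector<std::string> logical;
  std::vector<std::string> loop;
  std::vector<std::string> additional_ids;
};

// Return a copy of `ids` with every occurrence of `from` replaced by `to`.
std::vector<std::string> renameIds(
    std::vector<std::string> ids,
    const std::string& from,
    const std::string& to) {
  for (auto& id : ids) {
    if (id == from) {
      id = to;
    }
  }
  return ids;
}

// Lossy rebuild: constructs a fresh domain from the logical and loop lists
// only, so additional_ids ends up empty in the result.
ToyDomain mutateLossy(
    const ToyDomain& d,
    const std::string& from,
    const std::string& to) {
  ToyDomain out;
  out.logical = renameIds(d.logical, from, to);
  out.loop = renameIds(d.loop, from, to);
  return out;
}

// Preserving rebuild: same as above, but also carries additional_ids over
// (applying the same rename to it).
ToyDomain mutatePreserving(
    const ToyDomain& d,
    const std::string& from,
    const std::string& to) {
  ToyDomain out = mutateLossy(d, from, to);
  out.additional_ids = renameIds(d.additional_ids, from, to);
  return out;
}
```

With a domain whose additional_ids holds "bS2{1}", renaming "iS1{i0}" through mutateLossy returns a domain with an empty additional_ids, while mutatePreserving keeps "bS2{1}". Whether the real fix takes this shape is not something this issue states.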

We currently do not preserve the additional IDs when mutating, as seen in this repro:

TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  Fusion* fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  auto tv0 = makeSymbolicTensor(1);
  fusion->addInput(tv0);

  auto tv1 = exp(tv0);

  fusion->addOutput(tv1);

  // We add a broadcast domain bS2{1}. This adds the new Broadcast ID to tv1->domain()->additionalIDs()
  // logical: [ iS1{i0} ]
  // loop: [ iS1{i0}, bS2{1} ]
  // additional IDs: [ bS2{1} ]
  tv1->broadcast(1);
  EXPECT_FALSE(tv1->domain()->additionalIDs().empty());

  // After this split we have
  // logical: [ iS1{i0} ]
  // loop: [ iS1{i0}, bS3{1}, bS4{2} ]
  // additional IDs: [ bS2{1} ]
  tv1->split(1, 2);
  EXPECT_FALSE(tv1->domain()->additionalIDs().empty());

  // Now register a mutation that will alter some IDs in the domain
  OptOutMutator mut;
  mut.registerMutation(tv1->axis(0)->extent(), IrBuilder::create<Val>(DataType::Index));
  TensorDomain* old_tensor_domain = tv1->domain();
  auto all_stmts = StmtSort::getStmts(
      fusion,
      /*traverse_members*/ true,
      /*traverse_attributes*/ true,
      /*traverse_siblings*/ true);
  for (auto stmt : all_stmts) {
    mut.dispatchMutate(stmt);
  }
  EXPECT_TRUE(tv1->domain() != old_tensor_domain) << "Mutation did not change the TensorDomain";

  EXPECT_FALSE(tv1->domain()->additionalIDs().empty())
      << "Mutation did not preserve additional IDs";
}

I noticed this in a very similar case when working on #3406.