NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

allocation order inference pass broken matmul scheduler tests #1810

Open jjsjann123 opened 7 months ago

jjsjann123 commented 7 months ago

The allocation order inference pass mutated the allocation domain of the output tensor in the matmul tests. I missed the broken CI because two propagation rules were merged in parallel and were never run through the tests together. :cry:

I'm temporarily disabling the optimization pass in the matmul tests to get CI green again. Let's discuss how to avoid this in the future (hint: maybe we should enforce the output stride order in the matmul tests).

naoyam commented 7 months ago

Is this a real error or just a validation error? In other words, should we avoid mutating the allocation domain of the output of the matmul test?

jjsjann123 commented 7 months ago

I think it's just a validation error in @Priya2698's example in #1775:

  auto tv0 = makeConcreteTensor(a_shape, DataType::Half);
  auto tv1 = makeConcreteTensor(b_shape, DataType::Half);
  auto tv0b = broadcast(tv0, {false, false, true}); // [M, K, 1]
  auto tv1b = broadcast(tv1, {true, false, false}); // [1, K, N]
  auto tv2 = fusedMultiplySum(tv0b, tv1b, {1});
  auto tv3 = makeConcreteTensor({m}, DataType::Half);
  auto tv4 = castOp(DataType::Float, tv3);
  auto tv5 = biasEpilogue(tv2, tv4);
  auto tv6 = castOp(DataType::Half, tv5);

  fusion->addInput(tv0);
  fusion->addInput(tv1);
  fusion->addInput(tv3);
  fusion->addOutput(tv6);

So the propagation chain is tv3 -> tv4 -> tv5 -> tv6. Once the pass updates the allocation order of the output tv6, we hit the assert.

Two things I think are worth discussing:

  1. Should we still be able to specify the output allocation order, @Priya2698? If "someone" specified the output allocation, should we still respect that? We probably shouldn't assert, so I think this is a legitimate issue that we should fix in validation.
  2. I should probably also follow @zasdfgbnm's suggestion on the propagation rule for broadcast here :laughing:. That would at least avoid exposing the issue in the test above.

Priya2698 commented 7 months ago

  1. Should we still be able to specify output allocation order?! @Priya2698. If "someone" specified output allocation, should we still respect that? We probably shouldn't assert, so I think this is a legit issue we should fix on validation.

Do you mean the case where the user specifies an output allocation order? Another question: what happens when some operator in the fusion is not handled yet? You mentioned that fusedMultiplySum is not supported in the propagation rules. We hit this assert because no allocation order is propagated for tv2, so we end up updating the allocation order based on tv3.

jjsjann123 commented 7 months ago

Do you mean if the user specifies an output allocation order?

Yes. If the user specifies an output allocation order, we need to respect that. Admittedly, this is low priority for the matmul case right now.

Regarding allocation domain propagation: should it be the pass's responsibility to figure out the right output allocation for the operation? I think the propagation rule should try to help. But if fusedMultiplySum requires its output to be in a certain allocation order, that needs to be enforced and represented before hitting the optimization passes.

In this example, it's more than just the fusedMultiplySum output; its epilogue fusion is involved as well. This is going to be tricky for the pass to handle (not that it's any easier anywhere else :) ).