NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Cannot define an `MmaOp` where batch dimension is `Broadcast` #2273

Open · jacobhinkle opened this issue 1 month ago

jacobhinkle commented 1 month ago

The following test fails when it tries to create an `MmaOp`:

// Single batch dimension which is broadcast
TEST_F(GPUTTensorCoreTest, FusionAmpereBroadcastBatchMatmul_CUDA) {
  auto layout = MmaLayout::TN;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

  auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
  auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  auto tv2 = fusedMultiplySum(
      broadcast(tv0, {true, false, false, false}),
      broadcast(tv1, {true, false, false, false}),
      {-1});
/*
C++ exception with description "details.bcasts.empty() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/ir/utils.cpp":1268, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. MmaOp output: has broadcast domains.
Exception raised from operator() at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:1268 (most recent call first):
*/

  fusion.addOutput(tv2);
}

This is what caused the failure at https://github.com/NVIDIA/Fuser/blob/5228f89768fc47b6b542ecdbf19fae389c0c1e1f/tests/cpp/test_combine_mul_sum.cpp#L121, which is why that test currently checks that we cannot translate that case. However, I think that case should be supported, and we should instead fix the `MmaOp` constructor so that it does not balk at such cases.

jacobhinkle commented 1 month ago

Relevant comment from the PR that introduced this check: https://github.com/NVIDIA/Fuser/pull/131/files#r1164511926. It seems we do plan to support broadcast batch dimensions, but `getMmaOpDetails` and `getInputLayout` don't currently handle them. This might change with #2272, since that PR uses IdModel and the allocation domain to determine the layout instead of pattern matching.