NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Allow reductions in matmul epilogues #2213

Open jacobhinkle opened 2 months ago

jacobhinkle commented 2 months ago

We currently support pointwise epilogues, but there is an opportunity to fuse reductions into them as well. For example, fusing a cross-entropy loss into the epilogue would let us reduce total IO considerably (ignoring the possible need to save matrix outputs for the backward pass). PyTorch has an open issue requesting such a fused op: https://github.com/pytorch/pytorch/issues/124480.
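
Concretely, for row $i$ of the GEMM output $z$ with integer label $L_i$, the fusion below computes the numerically stabilized form

$$m_i = \max_j z_{ij}, \qquad \mathrm{lse}_i = m_i + \log \sum_j e^{z_{ij} - m_i}, \qquad \mathrm{xent}_i = \mathrm{lse}_i - z_{i,L_i},$$

plus the quantity we may want to save for backward, $\partial\,\mathrm{xent}_i / \partial z_{ij} = \mathrm{softmax}(z_i)_j - \mathbb{1}[j = L_i]$.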

I tried the following straightforward approach:

      Fusion fusion;
      FusionGuard fg(&fusion);

      auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

      auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
      auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

      fusion.addInput(tv0);
      fusion.addInput(tv1);

      tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
      tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
      auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

      // integer class labels, one per row
      auto tv3 = makeContigConcreteTensor({M}, DataType::Int32);
      fusion.addInput(tv3);
      // one-hot mask: compare each row's label against class indices 0..N-1
      auto is_true_label =
          eq(broadcast(tv3, {false, true}),
             broadcast(
                 iota(
                     IrBuilder::create<Val>(N),
                     /*start=*/nullptr,
                     /*step=*/nullptr,
                     DataType::Int32),
                 {true, false}));
      // allreduce log sum exp in N dimension
      auto m = max(tv2, {1});
      auto m_bcast = broadcast(m, {false, true});
      auto e = exp(sub(tv2, m_bcast));
      auto sumexp = sum(e, {1});
      auto lse_noadj = log(sumexp);
      auto lse_nobcast = add(lse_noadj, m);
      auto lse = broadcast(lse_nobcast, {false, true});

      // xent_i = lse_i - z_i[L_i]: mask with the one-hot before reducing so
      // the row sum picks out only the true-label column
      auto xent = sum(
          where(is_true_label, sub(lse, tv2), fusion.zeroVal(DataType::Float)),
          {1});

      fusion.addOutput(xent);

      // NOTE: we might also want to save this output for backward
      auto yhatminusy = sub(
          div(e, broadcast(sumexp, {false, true})),
          castOp(DataType::Float, is_true_label));
      fusion.addOutput(castOp(DataType::Half, yhatminusy));

However, this fails during scheduleOutputTensor, which expects tiled 2D outputs:

C++ exception with description "tile_size_m == gemm_tile.cta_tile.m INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/scheduler/matmul.cpp":591, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: 128, actual: 32

To address this, we need to remove any explicit or implicit assumption that every fusion output is 2D. That might mean we just need to pick appropriate epilogue reference tensors, schedule those, and then propagate.
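
One way to picture that direction (a hand-written sketch built on nvFuser's generic propagation utilities, not the matmul scheduler's actual code path; `yhatminusy_half` is a hypothetical name for the castOp output above, and `gemm_tile` stands for the scheduler's tile options):

      // Pick the widest epilogue output as the reference and tile it with the
      // CTA tile; the 1D output `xent` then simply inherits the M tiling when
      // the transforms are propagated.
      TensorView* ref = yhatminusy_half; // the [M, N] output saved for backward
      ref->split(0, gemm_tile.cta_tile.m); // [Mo, Mi, N]
      ref->split(-1, gemm_tile.cta_tile.n); // [Mo, Mi, No, Ni]
      ref->reorder({{1, 2}, {2, 1}}); // [Mo, No, Mi, Ni]
      ref->axis(0)->parallelize(ParallelType::BIDx);
      ref->axis(1)->parallelize(ParallelType::BIDy);
      // Replay the reference's transforms onto the rest of the epilogue and
      // copy its parallelization, rather than assuming every output is 2D.
      TransformPropagator propagator(ref);
      MaxRootDomainInfoSpanningTree(ref).traverse(&propagator);
      scheduler_utils::parallelizeAllLike(ref);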

jacobhinkle commented 2 months ago

Edited the example to show the tensor we might want to save for backward, which is a full matrix. That matrix would be saved by the loss kernel in the unfused case as well, so overall the fusion saves IO proportional to 2MN, as expected, by not materializing the intermediate GEMM output.
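
For a rough sense of scale (hypothetical sizes): with M = 8192 rows and N = 32768 classes in half precision, the intermediate logits matrix holds M*N half-words, so writing it out of the matmul kernel and reading it back into a separate loss kernel costs

      2 * M * N * 2 B = 2 * 8192 * 32768 * 2 B = 1 GiB

of global-memory traffic, all of which fusion avoids; the yhatminusy matrix is written either way.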

YouJiacheng commented 2 months ago

Is it possible to support a row-wise reduction epilogue when the matmul is performed tile by tile? I am curious.

jacobhinkle commented 1 month ago

> Is it possible to support a row-wise reduction epilogue when the matmul is performed tile by tile? I am curious.

It seems like it should be possible using the usual grid reduction approaches, with either atomics or semaphores coordinating the reduction. If the reduction result is used to compute other values in the epilogue, then we would need a cooperative launch to ensure that all CTAs in a row of tiles are placed in the same wave, so that we can broadcast the reduction result back across the grid.
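
For intuition, here is a hand-written CUDA sketch of that coordination (not nvFuser-generated code; names are made up, and `row_sum` is assumed zero-initialized):

      #include <cooperative_groups.h>
      namespace cg = cooperative_groups;

      // One thread per element: every CTA covering a row atomically
      // accumulates its partial row sums, then a grid-wide sync makes the
      // finished sums visible to all CTAs so the epilogue can keep using
      // them -- an allreduce in the N dimension. grid.sync() requires
      // launching with cudaLaunchCooperativeKernel so the whole grid is
      // co-resident in a single wave.
      __global__ void rowwiseAllreduceEpilogue(
          const float* z, float* row_sum, float* out, int M, int N) {
        cg::grid_group grid = cg::this_grid();
        int row = blockIdx.y * blockDim.y + threadIdx.y;
        int col = blockIdx.x * blockDim.x + threadIdx.x;
        if (row < M && col < N) {
          atomicAdd(&row_sum[row], z[row * N + col]); // partial reduction
        }
        grid.sync();
        if (row < M && col < N) {
          out[row * N + col] = z[row * N + col] / row_sum[row]; // broadcast use
        }
      }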