NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")
Other
256 stars 51 forks source link

Self-mapping error when compiling 1D bias linear fusion with smem-epilogue and no split-K #2979

Open jacobhinkle opened 1 week ago

jacobhinkle commented 1 week ago

The following test fails currently:

// Repro: scheduling a linear op with a 1-D bias using the shared-memory
// epilogue (and no split-K) trips the IdModel self-mapping assertion during
// compilation. The failure happens in fe.compileFusion at the end.
TEST_F(MatmulSchedulerTest, SelfMappingErrorSmemEpilogue1dBias) {
  // Presumably restricts the test to CUDA arch range [7.5, 9.0) — TODO confirm
  // the guard macro's semantics.
  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

  Fusion fusion_obj;
  Fusion* fusion = &fusion_obj;
  FusionGuard fg(fusion);

  // A and B operands of the linear op: 2-D contiguous half-precision inputs.
  auto tv0 = makeContigTensor(2, DataType::Half);
  auto tv1 = makeContigTensor(2, DataType::Half);
  fusion->addInput(tv0);
  fusion->addInput(tv1);

  // The 1-D bias input — the ingredient that triggers the bug (see issue title).
  auto tv2 = makeContigTensor(1, DataType::Half);
  fusion->addInput(tv2);
  TensorView* tv3 = linear(tv0, tv1, tv2);

  fusion->addOutput(tv3);

  // Tile configuration: CTA / warp / instruction tiles given as (M, N, K).
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(64, 128, 32);
  gemm_tile.warp_tile = GemmTile(32, 32, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  MatmulParams params;
  params.supported_vec_size = {8, 8, 4};
  params.mma_macro = MmaMacro::Ampere_16_8_16;
  params.tile_sizes = gemm_tile;
  params.async_gmem_load_operands = true;
  // Two-stage circular buffering for both smem writes and reads.
  params.circular_buffer_options.circular_buffer_smem_write = true;
  params.circular_buffer_options.circular_buffer_smem_read = true;
  params.circular_buffer_options.smem_circular_buffer_stage = 2;
  mma_utils::MmaDataTypes data_types = {
      DataType::Half, DataType::Half, DataType::Float};
  // Let the heuristic pick use_smem_epilogue / promote_prologue_smem_reuse;
  // ignore_occupancy_drop=true forces the smem epilogue path under test even
  // if it would reduce occupancy.
  std::tie(params.use_smem_epilogue, params.promote_prologue_smem_reuse) =
      mma_utils::generateSharedMemoryEpilogueHeuristics(
          gemm_tile,
          params.circular_buffer_options.smem_circular_buffer_stage,
          data_types,
          /*ignore_occupancy_drop=*/true);
  scheduleMatmul(fusion, params);

  // Concrete problem sizes; t1 is (N, K) as expected by linear's weight input.
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  int64_t M = 256, N = 128, K = 64;
  at::Tensor t0 = at::randn({M, K}, options);
  at::Tensor t1 = at::randn({N, K}, options);
  at::Tensor t2 = at::randn({K}, options);
  std::vector<c10::IValue> inputs{t0, t1, t2};

  // Compilation is where the "self-mapped loop domains" assertion fires;
  // the kernel is never run.
  FusionExecutor fe;
  fe.compileFusion(fusion, inputs);
}

This gives the following error:

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/id_model/id_model.cpp":912, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected loop domains are mapped in the loop graph. Tensor: T8_l___half[ iS53{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 64) )}, bS55{( ceilDiv(1, 128) )}, iS57{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 32) )}, iS125{( ceilDiv(32, 16) )}, ithreadIdx.z127{( ceilDiv(64, 32) )}, bthreadIdx.y129{( ceilDiv(128, 32) )}, iS131{( ceilDiv(32, 16) )}, bS133{( ceilDiv(32, 8) )}, bS134{8}, iS199{( ceilDiv(( ceilDiv(16, 8) ), 2) )}, ithreadIdx.x239{( 8 * 4 )}, iS197{( ceilDiv(( ceilDiv(16, 2) ), 4) )}, iS200{2}, iS196{2} ] ca_pos( 5 ) produce_pos( 5 ). Mapped loop domains: ithreadIdx.z127{( ceilDiv(64, 32) )} and bthreadIdx.y129{( ceilDiv(128, 32) )} Exception raised from validateLoopGraphHasNoSelfMappedLeafDomains at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:912 (most recent call first):

The tensor T8 is the mma instruction's "A" input after translating the LinearOp node.

We succeed in compiling this fusion whenever we: [list truncated in this capture — the original issue enumerates the conditions, e.g. disabling one of the features named in the title]
zasdfgbnm commented 1 week ago

Cross-posting https://github.com/NVIDIA/Fuser/pull/2669#discussion_r1691945234