KnowingNothing / MatmulTutorial

An Easy-to-understand TensorOp Matmul Tutorial
Apache License 2.0
286 stars 30 forks

How to disable SMEM swizzle? #14

Open ziyuhuang123 opened 1 month ago

ziyuhuang123 commented 1 month ago

I see this:

enum class SmemSwizzleBits : uint8_t {
  DISABLE = 0,
  B32 = 1,
  B64 = 2,
  B128 = 3,
};

And I changed the swizzle bits to 0 here:

  // tensor_map
  utils::TmaDescriptor tensormap_a =
      utils::make_tma_copy_desc<BLOCKM, BLOCKK, 3>(
          gemm_params.A, gemm_params.M, gemm_params.K, Swizzle<3, 4, 3>{},  // changed to Swizzle<0, 4, 3>{}
          CLUSTER_N);
  utils::TmaDescriptor tensormap_b =
      utils::make_tma_copy_desc<BLOCKN, BLOCKK, 3>(
          gemm_params.B, gemm_params.N, gemm_params.K, Swizzle<3, 4, 3>{},  // changed to Swizzle<0, 4, 3>{}
          CLUSTER_M);

But I get an error:

Copying results done!
Wrong answer! 14663841 errors! 87.4033%
Average diff = 596.463
test: ../../../include/common.h:396: void assert_allclose(DType*, DType*, std::vector<int>, float, bool) [with DType = __half]: Assertion `errors == 0' failed.
Aborted (core dumped)

Any suggestions? Thanks!
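For reference, here is a minimal host-side sketch of what a Swizzle<B, M, S> functor does to a byte offset, assuming the repo's Swizzle follows the CuTe convention (XOR B bits taken from position M + S into position M). `SwizzleSketch` is a hypothetical stand-in, not the repo's type. It shows that Swizzle<3, 4, 3> permutes the 16-byte chunks within each 128-byte row, while Swizzle<0, 4, 3> is the identity, so the two settings place the same data at different shared memory addresses.

```cpp
#include <cstdint>

// Hypothetical stand-in for a CuTe-style Swizzle<B, M, S> (positive S only).
// B: number of bits to XOR, M: lowest bit that is modified, S: shift distance.
template <int B, int M, int S>
struct SwizzleSketch {
  static_assert(B >= 0 && M >= 0 && S > 0, "sketch covers positive shifts only");

  // Mask selecting the B source bits at positions [M + S, M + S + B).
  static constexpr uint32_t kSrcMask = ((1u << B) - 1u) << (M + S);

  // XOR the source bits into positions [M, M + B). With B == 0 the mask is
  // zero, so the offset is returned unchanged.
  static constexpr uint32_t apply(uint32_t byte_offset) {
    return byte_offset ^ ((byte_offset & kSrcMask) >> S);
  }
};

// Byte offset of 16-byte chunk `chunk` in 128-byte row `row` (assumed layout).
constexpr uint32_t chunk_offset(uint32_t row, uint32_t chunk) {
  return row * 128u + chunk * 16u;
}

// 128B swizzle: row 1, chunk 0 lands where chunk 1 would be (0 ^ 1 == 1).
static_assert(SwizzleSketch<3, 4, 3>::apply(chunk_offset(1, 0)) ==
                  chunk_offset(1, 1),
              "row 1 is permuted");
// Swizzle disabled: every offset maps to itself.
static_assert(SwizzleSketch<0, 4, 3>::apply(chunk_offset(1, 0)) ==
                  chunk_offset(1, 0),
              "B == 0 is the identity");
```

Because the layout in shared memory changes, whatever reads the tile afterwards has to be configured for the same layout as whatever wrote it.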

ziyuhuang123 commented 1 month ago

I also tried changing the shared memory descriptor here:

/// make shared memory descriptor
template <class PointerType>
DEVICE GmmaDescriptor make_smem_desc(PointerType smem_ptr) {
  GmmaDescriptor desc;
  uint32_t uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  desc.bitfield.start_address_ = uint_ptr >> 4;
  // desc.bitfield.layout_type_ =
  //     0x1;  /// swizzle 128B because we use Swizzle<3,4,3>
  desc.bitfield.layout_type_ =
    0x0;  /// swizzle disabled because we use Swizzle<0,4,3> (changed from 0x1)
  desc.bitfield.leading_byte_offset_ = 0x1;  /// no use
  desc.bitfield.stride_byte_offset_ =
      64;  /// how many 128bits-rows needed between two core matrices
  desc.bitfield.base_offset_ = 0x0;
  return desc;
}

But the result is still incorrect.
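As a sanity check on the descriptor fields, here is a small sketch of the unit arithmetic. It relies on two things visible in make_smem_desc above: the start address is stored as `uint_ptr >> 4`, and the stride comment counts 128-bit (16-byte) rows. The helper names, the fp16 element size, and BLOCKK = 64 are assumptions made for illustration only.

```cpp
#include <cstdint>

// Hypothetical helpers: descriptor fields drop the low 4 bits, so they are
// expressed in 16-byte (128-bit) units.
constexpr uint32_t to_desc_units(uint32_t bytes) { return bytes >> 4; }
constexpr uint32_t from_desc_units(uint32_t units) { return units << 4; }

// Assumed tile parameters (not taken from the thread): fp16, BLOCKK = 64.
constexpr uint32_t kElemBytes = 2;
constexpr uint32_t kBlockK = 64;
constexpr uint32_t kRowBytes = kBlockK * kElemBytes;  // 128 bytes per row
constexpr uint32_t kRowsPerGroup = 8;                 // rows per 8-row group

// stride_byte_offset_ = 64 corresponds to 64 * 16 = 1024 bytes ...
static_assert(from_desc_units(64) == 1024, "64 units of 16 bytes");
// ... which, under the assumptions above, is 8 rows of 128 bytes each.
static_assert(to_desc_units(kRowBytes * kRowsPerGroup) == 64,
              "8 rows x 128 bytes");
```

Under these assumptions the value 64 matches a layout built from full 128-byte rows. Changing layout_type_ alone leaves the stride and leading offsets describing that layout, so it may be worth checking whether they still match how the data is arranged once swizzling is turned off.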