ROCm / composable_kernel

Composable Kernel: Performance Portable Programming Model for Machine Learning Tensor Operators
https://rocm.docs.amd.com/projects/composable_kernel/en/latest/

[Development] How to create a temp tile window from a pointer? #1586

Closed LeiWang1999 closed 1 month ago

LeiWang1999 commented 1 month ago

Problem Description

Hi all, we're working on a tile language similar to Triton, and I'm adding HIP codegen backed by Composable Kernel (for CUDA, we use CuTe).

template <int M, int N, int K, bool TransposeA, bool TransposeB, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
 public:
  using A_type = A_type_raw;
  using B_type = B_type_raw;
  using C_type = C_type_raw;
//   using Instruction = DispatchInstruction<A_type, B_type, C_type>;
//   using DeviceGemmInstance = typename Instruction::DeviceGemmInstance;

  struct Problem {
    // This part comes from the Codegen
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 32;

    static constexpr ck_tile::index_t M_Warp = 2;
    static constexpr ck_tile::index_t N_Warp = 2;
    static constexpr ck_tile::index_t K_Warp = 1;

    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = 8;

    using ADataType = A_type;
    using BDataType = B_type;
    using CDataType = C_type;
    using BlockGemmShape =
        ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                               ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                               ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
  };

  static CK_TILE_DEVICE void body(A_type* pA, B_type* pB, C_type* pC) {

    static constexpr auto I0 = ck_tile::number<0>{};
    static constexpr auto I1 = ck_tile::number<1>{};
    static constexpr auto I2 = ck_tile::number<2>{};

    // Load A and B from Shared memory
    using AccDataType = float;
    using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
    using WarpTile = typename Problem::BlockGemmShape::WarpTile;
    using WarpGemm =
        ck_tile::WarpGemmMfmaDispatcher<typename Problem::ADataType, typename Problem::BDataType,
                               AccDataType, WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2),
                               false>;
    using BlockGemmPolicy =
        ck_tile::BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                              typename Problem::BDataType,
                                              typename Problem::CDataType, BlockWarps, WarpGemm>;
    auto compute = ck_tile::BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
    // Then do GEMM and store the result to C.
    // Question: how do we construct a_block_window / b_block_window over the
    // shared-memory pointers pA / pB, and c_block_tensor for the accumulator?
    compute(c_block_tensor, a_block_window, b_block_window);

  }
};


Do we have any method to create a tile window for the computation?

Operating System

Ubuntu 20.04

CPU

AMD

GPU

AMD Instinct MI250

Other

No response

ROCm Version

ROCm 5.7.0

ROCm Component

Other

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

LeiWang1999 commented 1 month ago
  static CK_TILE_DEVICE void body(A_type* pA, B_type* pB, C_type* pC) {

    static constexpr auto I0 = ck_tile::number<0>{};
    static constexpr auto I1 = ck_tile::number<1>{};
    static constexpr auto I2 = ck_tile::number<2>{};

    // Load A and B from Shared memory
    using AccDataType = float;
    using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
    using WarpTile = typename Problem::BlockGemmShape::WarpTile;
    using WarpGemm =
        ck_tile::WarpGemmMfmaDispatcher<typename Problem::ADataType, typename Problem::BDataType,
                               AccDataType, WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2),
                               false>;
    using BlockGemmPolicy =
        ck_tile::BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                              typename Problem::BDataType,
                                              typename Problem::CDataType, BlockWarps, WarpGemm>;
    auto block_gemm = ck_tile::BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
    // Then do GEMM and store the result to C
    // Create Tile Window

    constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
    constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

    constexpr auto a_lds_block_desc =
        make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock));

    auto a_lds_block = make_tensor_view<address_space_enum::lds>(pA, a_lds_block_desc);

    constexpr auto b_lds_block_desc =
        make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock));

    auto b_lds_block = make_tensor_view<address_space_enum::lds>(pB, b_lds_block_desc);

    // A LDS tile for block GEMM
    auto a_lds_gemm_window = make_tile_window(
        a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

    // B LDS tile for block GEMM
    auto b_lds_gemm_window = make_tile_window(
        b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

    // Acc register tile
    auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);

  }
};

This solution looks good to me.

LeiWang1999 commented 1 month ago
  static CK_TILE_DEVICE void body(A_type* pA, B_type* pB, C_type* pC) {
    static constexpr auto I0 = ck_tile::number<0>{};
    static constexpr auto I1 = ck_tile::number<1>{};
    static constexpr auto I2 = ck_tile::number<2>{};

    // Load A and B from Shared memory
    using AccDataType = float;
    using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
    using WarpTile = typename Problem::BlockGemmShape::WarpTile;
    using WarpGemm =
        ck_tile::WarpGemmMfmaDispatcher<typename Problem::ADataType, typename Problem::BDataType,
                                        AccDataType, WarpTile::at(I0), WarpTile::at(I1),
                                        WarpTile::at(I2), false>;
    using BlockGemmPolicy = ck_tile::BlockGemmASmemBSmemCRegV1CustomPolicy<
        typename Problem::ADataType, typename Problem::BDataType, typename Problem::CDataType,
        BlockWarps, WarpGemm>;
    auto block_gemm = ck_tile::BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
    // Then do GEMM and store the result to C
    // Create Tile Window

    constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
    constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
    constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

    constexpr auto a_lds_block_desc =
        make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock));

    auto a_lds_block = make_tensor_view<address_space_enum::lds>(pA, a_lds_block_desc);

    constexpr auto b_lds_block_desc =
        make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock));

    auto b_lds_block = make_tensor_view<address_space_enum::lds>(pB, b_lds_block_desc);

    // A LDS tile window for block GEMM
    auto a_lds_gemm_window = make_tile_window(
        a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

    // B LDS tile window for block GEMM
    auto b_lds_gemm_window = make_tile_window(
        b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

    // Acc register tile
    auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};

    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
    // print C (via the raw pointer) for inspection
    if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
        printf("============c========\n");
        for(int i = 0; i < 2; i++) {
            for(int j = 0; j < 128; j++) {
                printf("%f ", (float)pC[i * 128 + j]);
            }
            printf("\n");
        }
    }

  }

Does anybody have an idea for connecting c_block_tile with the local pointer pC?

LeiWang1999 commented 1 month ago

Currently, the thread_buffer design holds an internal array (value_type data[N]), which leads to unnecessary register consumption and redundant memory copies when initializing thread_buffer instances, especially in performance-critical code. Every time we create an instance, it allocates storage for the array, even when an external data buffer already exists.

To improve this, I propose introducing a value_type* data_ptr, which allows users to directly modify the internal data pointer, offering more flexibility. However, this needs careful consideration due to template-based behavior and pointer ownership concerns.

Potential Revision

Instead of directly adding a data_ptr, we can introduce an optional mechanism to switch between internal array storage and external memory management, depending on whether the user provides a pointer or not. This way, we maintain backward compatibility while also offering a more flexible design. Here’s an example:

#include <type_traits> // for std::remove_cvref_t (C++20)

template<typename T_, index_t N_>
struct thread_buffer {
    using value_type = std::remove_cvref_t<T_>;
    static constexpr index_t N = N_;

    value_type* data_ptr = nullptr; // Pointer to allow external memory management
    bool owns_memory = true;        // Flag to track whether thread_buffer owns the memory

    // Default constructor: internally manages memory
    thread_buffer() : data_ptr(internal_data) {}

    // Constructor that accepts an external data pointer
    explicit thread_buffer(value_type* external_ptr) 
        : data_ptr(external_ptr), owns_memory(false) {}

    // Destructor: only deallocates if we own the memory
    ~thread_buffer() {
        if (owns_memory) {
            // Optional: manage deallocation here if necessary
        }
    }

    // Accessor functions for backward compatibility
    value_type& operator[](index_t i) {
        return data_ptr[i];
    }

    const value_type& operator[](index_t i) const {
        return data_ptr[i];
    }

private:
    value_type internal_data[N]; // Internal array for default use; note it
                                 // still occupies storage even when an
                                 // external pointer is supplied
};

Key point: this approach ensures minimal disruption to the existing design while offering more flexibility for advanced use cases that involve custom memory management.
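To make the proposal concrete, here is a self-contained, simplified version of the sketch above (using a plain `std::int32_t` stand-in for `ck_tile::index_t` and dropping `std::remove_cvref_t`), together with both usage modes:

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t

template <typename T_, index_t N_>
struct thread_buffer {
    using value_type = T_;
    static constexpr index_t N = N_;

    value_type* data_ptr = nullptr; // points at internal or external storage
    bool owns_memory = true;        // tracks whether we own the storage

    // Default constructor: internally managed storage
    thread_buffer() : data_ptr(internal_data) {}

    // Constructor accepting an external data pointer
    explicit thread_buffer(value_type* external_ptr)
        : data_ptr(external_ptr), owns_memory(false) {}

    value_type& operator[](index_t i) { return data_ptr[i]; }
    const value_type& operator[](index_t i) const { return data_ptr[i]; }

private:
    value_type internal_data[N]; // internal array for the default mode
};

// Owning mode: writes land in the internal array.
inline int sum_owning() {
    thread_buffer<int, 4> buf;
    for (index_t i = 0; i < 4; ++i) buf[i] = i + 1;
    return buf[0] + buf[1] + buf[2] + buf[3]; // 1+2+3+4 = 10
}

// Non-owning mode: writes land in caller-provided storage.
inline int external_after_write() {
    int external[4] = {0, 0, 0, 0};
    thread_buffer<int, 4> buf(external);
    buf[2] = 42;
    return external[2]; // 42: the write went through to the external array
}
```

This is only an illustration of the two modes; the real ck_tile thread_buffer would additionally need its vectorized accessors (get_as etc.) adapted.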

carlushuang commented 1 month ago

@LeiWang1999 thanks for reaching out and for the suggestions. In fact, modern compilers are clever enough to remove all of the temporarily created thread_buffers; we call this technique "local scratch" internally, and it is used everywhere inside CK's internal source code. But the idea of creating a non-owning tensor is good; we can consider it as an option for the tensor view design.
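For readers unfamiliar with the pattern, a minimal plain-C++ sketch of the local-scratch idea (my illustration, not CK code): intermediate values are staged in a small fixed-size stack array, and with constant bounds and full inlining the compiler can typically promote that array to registers and elide the copies entirely.

```cpp
#include <array>

// "Local scratch" pattern: a temporary thread-local staging buffer.
inline float dot4(const float* a, const float* b) {
    std::array<float, 4> scratch{}; // small fixed-size scratch array
    for (int i = 0; i < 4; ++i)
        scratch[i] = a[i] * b[i];   // copies a compiler can optimize away
    float acc = 0.0f;
    for (int i = 0; i < 4; ++i)
        acc += scratch[i];
    return acc;
}

// Small driver with known inputs, for checking the semantics.
inline float dot4_example() {
    float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float b[4] = {1.0f, 1.0f, 1.0f, 1.0f};
    return dot4(a, b); // 1+2+3+4 = 10
}
```

Whether the scratch array actually stays in registers is a compiler decision, which is exactly why the follow-up question below about inspecting the generated ISA matters.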

LeiWang1999 commented 1 month ago

@carlushuang thanks for your response! Just to confirm, are you saying modern compilers can automatically remove redundant copies like this?

  c_vec = c_block_tile.get_thread_buffer().template get_as<CVec>(number<0>{});

  for(int j = 0; j < c_block_tile.get_thread_buffer_size(); j++) {
      pC[j] = c_block_tile.get_thread_buffer()[j];
  }

Is there a way for us to check whether hipcc is applying this optimization?

carlushuang commented 1 month ago

Exactly. The best way is to dump the ISA and directly check what is generated, to see whether there are any unnecessary v_mov_b32 instructions moving data from one register to another. There should also be ways to inspect the generated IR, but I haven't explored that much. hipcc supports this out of the box since it is an LLVM-based compiler.

LeiWang1999 commented 1 month ago

@carlushuang thanks for your response! Would you mind providing the command for dumping the ISA with hipcc :)?

taylding-amd commented 1 month ago

Hi @LeiWang1999, you can compile your HIP code to assembly using the -S option. This will generate an assembly file, which should contain the ISA information you need:

hipcc -S your_file.cpp -o output_file