Here is the implementation of a row-wise reduce operation in TK (https://github.com/HazyResearch/ThunderKittens/blob/main/src/ops/warp/register/tile/reductions.cuh#L27):
```cuda
/**
 * @brief Perform a row-wise reduction on a matrix in row-major layout.
 *
 * This function template performs a parallel reduction across the rows of a matrix using a specified operation.
 * It leverages warp shuffle functions for efficient intra-warp communication.
 *
 * @tparam op The operation to be applied for reduction.
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type with row layout.
 * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when reset is false.
 */
template<typename op, ducks::rv::all V, ducks::rt::row_layout T, bool reset>
__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) {
    // I actually like these static asserts because they give more verbose errors when things go wrong.
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::inner_dim == rt_base<typename T::dtype, typename T::layout>::col_vec_pack); // compatible layout
    static_assert(V::outer_dim == T::height); // compatible size
    using dtype = V::dtype;
    const int leader = threadIdx.x & 0x1C; // 11100 in binary
    #pragma unroll
    for(int i = 0; i < src.height; i++) {
        dtype accum_top_row    = op::template op<dtype>(src.tiles[i][0].data[0], src.tiles[i][0].data[2]);
        dtype accum_bottom_row = op::template op<dtype>(src.tiles[i][0].data[1], src.tiles[i][0].data[3]);
        #pragma unroll
        for(int j = 1; j < src.width; j++) {
            #pragma unroll
            for(int k = 0; k < src.packed_per_tile; k+=2) {
                accum_top_row    = op::template op<dtype>(accum_top_row,    src.tiles[i][j].data[k+0]);
                accum_bottom_row = op::template op<dtype>(accum_bottom_row, src.tiles[i][j].data[k+1]);
            }
        }
        dtype accum_packed;
        accum_packed.x = op::template op<base_types::packing<dtype>::unpacked_type>(accum_top_row.x, accum_top_row.y);
        accum_packed.y = op::template op<base_types::packing<dtype>::unpacked_type>(accum_bottom_row.x, accum_bottom_row.y);
        // Now we need to do a lil shuffle to make everyone happy.
        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2));
        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1));
        accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader);
        if(reset) {
            row_accum[i][0] = accum_packed;
        }
        else {
            row_accum[i][0] = op::template op<dtype>(src_accum[i][0], accum_packed);
        }
    }
}
```
First, TK iterates over the base subtiles according to the tile's height and width. Within each subtile, every thread holds partial data from two rows in its private registers, so TK first performs a partial reduction along those rows inside each thread. After this per-thread reduction, the threads that share a row exchange their partial results via warp shuffles. Finally, the combined result is stored in the row_accum array.
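To make the data ownership concrete, here is a standalone sketch (not TK code; the exact fragment layout is my assumption based on the usual mma accumulator mapping): in a 16x16 row-layout base tile, lane l owns rows l/4 and l/4+8, so the four lanes {4r, 4r+1, 4r+2, 4r+3} jointly cover rows r and r+8. That is why two shuffles (offsets 2 and 1) plus a broadcast from the quad leader are enough:

```cuda
#include <cuda_runtime.h>

// Sketch only: row-wise sum over a 16x16 tile whose elements are distributed across a
// warp in an mma-accumulator-style pattern (assumed): lane l owns rows l/4 and l/4+8,
// columns {2*(l%4), 2*(l%4)+1, 2*(l%4)+8, 2*(l%4)+9}.
__device__ void row_sum_16x16(const float2 frag[4], float &sum_top, float &sum_bottom) {
    const int lane   = threadIdx.x & 31;
    const int leader = lane & 0x1C;              // first lane of this 4-lane "row quad"

    // 1) reduce the elements each thread already holds (two rows per thread)
    float2 top    = make_float2(frag[0].x + frag[2].x, frag[0].y + frag[2].y); // row l/4
    float2 bottom = make_float2(frag[1].x + frag[3].x, frag[1].y + frag[3].y); // row l/4 + 8
    float t = top.x + top.y;
    float b = bottom.x + bottom.y;

    // 2) combine across the 4 lanes that share these two rows
    t += __shfl_down_sync(0xFFFFFFFF, t, 2);
    b += __shfl_down_sync(0xFFFFFFFF, b, 2);
    t += __shfl_down_sync(0xFFFFFFFF, t, 1);
    b += __shfl_down_sync(0xFFFFFFFF, b, 1);

    // 3) broadcast the quad leader's result so every lane holds the row sums
    sum_top    = __shfl_sync(0xFFFFFFFF, t, leader);
    sum_bottom = __shfl_sync(0xFFFFFFFF, b, leader);
}
```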
I would like to discuss how to add a reduce operation for RegTile. A simple reduce implementation requires each thread to contribute one value; a single warp reduce then produces the final result.
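For reference, a minimal sketch of that simple case, assuming a full warp with one float per lane (the names here are mine, not TK's):

```cuda
// Standard shuffle-down tree over a full warp; the total ends up in lane 0.
__device__ float warp_reduce_sum(float v) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(0xFFFFFFFF, v, offset);
    }
    return v; // valid in lane 0; broadcast with __shfl_sync if every lane needs it
}
```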
However, this assumes each thread can actually supply the value it is responsible for. After a GEMM, the results live in a RegTile, and the reduction then has to run along the rows. The data is not laid out row-by-row in the thread-private registers, yet a single warp reduce is supposed to process the information of one row. I would like to discuss how the threads should dispatch their data in this case.