TiledTensor / TiledCUDA

TiledCUDA is a highly efficient kernel template library designed to elevate CUDA C’s level of abstraction for processing tiles.
MIT License

Add Warp Reduce based on `RegTile`. #99

Closed KuangjuX closed 3 months ago

KuangjuX commented 3 months ago

I would like to discuss how to add a reduce operation for RegTile. A simple reduce implementation is as follows:

template <typename DType, typename Reduce>
DEVICE DType warp_reduce(DType data, unsigned mask, Reduce reduce) {
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
        data = reduce(data, __shfl_down_sync(mask, data, offset, 32));
    }

    return data;
}

This requires every thread in the warp to contribute one value; after a single warp reduce, lane 0 holds the final result.

After a GEMM operation, however, the results are held in a RegTile, and the reduction must be performed along the rows.

The data in each thread's private registers is not laid out by rows, whereas a single warp reduce needs to process the elements of one row. How should each thread choose the data it contributes?
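To make the shuffle pattern of the warp reduce above easier to check, here is a minimal host-side sketch in plain C++ that emulates it on a 32-element array. The names `Warp`, `shfl_down`, and `warp_reduce_emulated` are hypothetical and not part of TiledCUDA; the emulation assumes a full mask and a width of 32, where a lane whose source lane is out of range gets its own value back.

```cpp
#include <array>
#include <cassert>
#include <functional>

constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;  // one value per lane (hypothetical type)

// Host-side stand-in for __shfl_down_sync with a full mask and width 32:
// lane i reads the value of lane i + offset; if that lane does not exist,
// the lane gets its own value back.
Warp shfl_down(const Warp& data, int offset) {
    Warp out{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        const int src = lane + offset;
        out[lane] = (src < kWarpSize) ? data[src] : data[lane];
    }
    return out;
}

// Same loop structure as the device-side warp_reduce, applied to all lanes
// at once. Only lane 0 is guaranteed to hold the full reduction afterwards.
template <typename Reduce>
Warp warp_reduce_emulated(Warp data, Reduce reduce) {
    for (int offset = 16; offset >= 1; offset /= 2) {
        const Warp shifted = shfl_down(data, offset);
        for (int lane = 0; lane < kWarpSize; ++lane) {
            data[lane] = reduce(data[lane], shifted[lane]);
        }
    }
    return data;
}
```

Calling `warp_reduce_emulated(values, std::plus<float>())` leaves the sum of all 32 inputs in element 0. The other elements hold partial or repeated sums, which is why a row-wise reduce has to decide which lane owns the result.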

KuangjuX commented 3 months ago

Here is the implementation of a row-wise reduce operation in TK (ThunderKittens, https://github.com/HazyResearch/ThunderKittens/blob/main/src/ops/warp/register/tile/reductions.cuh#L27):

/**
 * @brief Perform a row-wise reduction on a matrix in row-major layout.
 *
 * This function template performs a parallel reduction across the rows of a matrix using a specified operation.
 * It leverages warp shuffle functions for efficient intra-warp communication.
 *
 * @tparam op The operation to be applied for reduction.
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type with row layout.
 * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when reset is false.
 */
template<typename op, ducks::rv::all V, ducks::rt::row_layout T, bool reset>
__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) {
    // I actually like these static asserts because they give more verbose errors when things go wrong.
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::inner_dim == rt_base<typename T::dtype, typename T::layout>::col_vec_pack); // compatible layout
    static_assert(V::outer_dim == T::height); // compatible size

    using dtype = V::dtype;

    const int leader = threadIdx.x & 0x1C; // 11100 in binary
    #pragma unroll
    for(int i = 0; i < src.height; i++) {
        dtype accum_top_row    = op::template op<dtype>(src.tiles[i][0].data[0], src.tiles[i][0].data[2]);
        dtype accum_bottom_row = op::template op<dtype>(src.tiles[i][0].data[1], src.tiles[i][0].data[3]);
        #pragma unroll
        for(int j = 1; j < src.width; j++) {
            #pragma unroll
            for(int k = 0; k < src.packed_per_tile; k+=2) {
                accum_top_row    = op::template op<dtype>(accum_top_row,    src.tiles[i][j].data[k+0]);
                accum_bottom_row = op::template op<dtype>(accum_bottom_row, src.tiles[i][j].data[k+1]);
            }
        }
        dtype accum_packed;
        accum_packed.x = op::template op<base_types::packing<dtype>::unpacked_type>(accum_top_row.x,    accum_top_row.y);
        accum_packed.y = op::template op<base_types::packing<dtype>::unpacked_type>(accum_bottom_row.x, accum_bottom_row.y);

        // Now we need to do a lil shuffle to make everyone happy.

        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2));
        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1));

        accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader);

        if(reset) {
            row_accum[i][0] = accum_packed;
        }
        else {
            row_accum[i][0] = op::template op<dtype>(src_accum[i][0], accum_packed);
        }
    }
}

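First, TK iterates over the subtiles of the register tile, by tile row (src.height) and tile column (src.width). Within each subtile, each thread holds partial data for two rows in its private registers: data[0] and data[2] belong to the top row, data[1] and data[3] to the bottom row. TK first reduces these values inside each thread, which yields one partial result per row (accum_packed.x for the top row, accum_packed.y for the bottom row).

After the in-thread reduction, the four threads that hold the same rows combine their partial results with warp shuffles (offsets 2 and 1). The result is then broadcast from the group's leader lane (threadIdx.x & 0x1C) and stored in row_accum, combined with src_accum when reset is false.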
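The cross-thread step described above can be sketched on the host as follows. This is an emulation under stated assumptions, not TK's code: it uses scalar floats instead of packed types, fixes the operation to addition, and assumes that four consecutive lanes share a row, as the leader mask `threadIdx.x & 0x1C` implies. The names `Lanes`, `shfl_down_emulated`, and `quad_row_sum` are hypothetical.

```cpp
#include <array>
#include <cassert>

constexpr int kLanes = 32;
using Lanes = std::array<float, kLanes>;  // one per-thread partial result per lane

// Host-side stand-in for a full-mask shuffle-down: lane i reads lane i + offset,
// or keeps its own value when that lane does not exist.
Lanes shfl_down_emulated(const Lanes& v, int offset) {
    Lanes out{};
    for (int lane = 0; lane < kLanes; ++lane) {
        const int src = lane + offset;
        out[lane] = (src < kLanes) ? v[src] : v[lane];
    }
    return out;
}

// Combine the partial sums of each group of four consecutive lanes, then
// broadcast the group's total from its leader lane to all four lanes.
Lanes quad_row_sum(Lanes partial) {
    // Two shuffle steps (offsets 2 and 1) leave the group total in the leader.
    for (int offset = 2; offset >= 1; offset /= 2) {
        const Lanes shifted = shfl_down_emulated(partial, offset);
        for (int lane = 0; lane < kLanes; ++lane) {
            partial[lane] += shifted[lane];
        }
    }
    // Leader of a lane: clear the two low bits (lane & 0x1C, i.e. 0b11100).
    Lanes out{};
    for (int lane = 0; lane < kLanes; ++lane) {
        out[lane] = partial[lane & 0x1C];
    }
    return out;
}
```

After the two shuffle steps only the leader lane of each group is guaranteed to hold the full row sum; the other lanes have mixed in values from the neighbouring group. The final broadcast from the leader is what makes all four lanes agree.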