TiledTensor / TiledCUDA

TiledCUDA is a highly efficient kernel template library designed to elevate CUDA C’s level of abstraction for processing tiles.
MIT License

feat(cell): Add a Reg level reduce based on `RegTile`. #101

Closed: KuangjuX closed this 3 months ago

KuangjuX commented 3 months ago

This PR implements row-major warp reduction based on RegTile, and provides RowSum and RowMax implementations based on the Reduce template functor.

I've added a simple test case for the reduce operation:

template <typename Element, typename RegLayout, typename GlobalLayout,
          typename BaseTile, typename WarpLayout, const tl::Layout kLayout,
          const copy::WarpReuse kMode, const int kHeight, const int kWidth>
__global__ void reg_reduce(Element* src) {
    using SrcLoadTile = GlobalTile<Element, GlobalLayout>;
    using DstLoadTile = RegTile<BaseTile, RegLayout>;
    using SrcReduceTile = DstLoadTile;
    using DstReduceTile = RegTile<Element, tl::RowMajor<kHeight, 2>>;

    SrcLoadTile src_load_tile(src);
    DstLoadTile dst_load_tile;
    DstReduceTile dst_reduce_tile;

    // Load data from global memory to register file
    copy::GlobalToRegLoader<DstLoadTile, WarpLayout, kMode> loader;
    loader(src_load_tile, dst_load_tile);
    __syncthreads();

    // Execute reduce operation.
    compute::MaxReduce<SrcReduceTile, kLayout> row_max;
    row_max(dst_load_tile, dst_reduce_tile);

    __syncthreads();

    if (thread(0)) {
        printf("Row Max:\n");
        printf("Thread 0:\n");
        dst_reduce_tile.dump_value();
    }

    if (thread(4)) {
        printf("Thread 4:\n");
        dst_reduce_tile.dump_value();
    }

    if (thread(8)) {
        printf("Thread 8:\n");
        dst_reduce_tile.dump_value();
    }

    __syncthreads();

    compute::SumReduce<SrcReduceTile, kLayout> row_sum;
    row_sum(dst_load_tile, dst_reduce_tile);

    __syncthreads();

    if (thread(0)) {
        printf("Row Sum:\n");
        printf("Thread 0:\n");
        dst_reduce_tile.dump_value();
    }

    if (thread(4)) {
        printf("Thread 4:\n");
        dst_reduce_tile.dump_value();
    }

    if (thread(8)) {
        printf("Thread 8:\n");
        dst_reduce_tile.dump_value();
    }
}

template <typename Element, typename RegLayout, typename GlobalLayout,
          typename BaseTile, typename WarpLayout, const tl::Layout kLayout,
          const copy::WarpReuse kMode, const int kHeight, const int kWidth>
void run_reg_reduce() {
    int kNumel = 16 * 16 * kHeight * kWidth;
    // Despite its name, this is the number of warps in the block.
    int kWarpSize = tl::get_numel<WarpLayout>;

    thrust::host_vector<Element> h_src(kNumel);
    for (int i = 0; i < kNumel; ++i) {
        h_src[i] = (Element)i;
    }

    thrust::device_vector<Element> d_src = h_src;

    reg_reduce<Element, RegLayout, GlobalLayout, BaseTile, WarpLayout, kLayout,
               kMode, kHeight, kWidth>
        <<<1, 32 * kWarpSize>>>(thrust::raw_pointer_cast(d_src.data()));
    // Wait for the kernel so that device-side printf output is flushed.
    cudaDeviceSynchronize();
}

TEST(TestRegReduce, reg_reduce_0) {
    using Element = float;
    using WarpLayout = tl::RowMajor<1, 1>;
    using RegLayout = tl::RowMajor<1, 1>;

    const int kHeight = 1;
    const int kWidth = 1;
    const copy::WarpReuse kMode = copy::WarpReuse::kCont;

    using GlobalLayout = tl::RowMajor<16 * kHeight, 16 * kWidth>;

    run_reg_reduce<Element, RegLayout, GlobalLayout, BaseTileRowMajor<Element>,
                   WarpLayout, tl::Layout::kRowMajor, kMode, kHeight, kWidth>();
}

As the code shows, the RegTile consists of a 1x1 arrangement of BaseTiles and is processed by a single warp. The input is initialized with sequential values to make debugging easier. The reg_reduce kernel runs both row_max and row_sum and prints the results, which are as follows:

Row Max:
Thread 0:
15.0, 143.0, 
Thread 4:
31.0, 159.0, 
Thread 8:
47.0, 175.0, 
Row Sum:
Thread 0:
120.0, 2168.0, 
Thread 4:
376.0, 2424.0, 
Thread 8:
632.0, 2680.0, 

These values match the expected per-row maxima and sums, so the results appear correct. In the reduce implementation, I commented out the following two lines: I did not understand their purpose, and enabling them produced incorrect results:

// TODO(KuangjuX): This line seems unnecessary?
// top_row = shuffle_down_sync(MASK_ALL, top_row, leader);
// bottom_row = shuffle_down_sync(MASK_ALL, bottom_row, leader);