TiledTensor / TiledCUDA

TiledCUDA is a highly efficient kernel template library designed to elevate CUDA C’s level of abstraction for processing tiles.
MIT License

Transfer data tile from global memory to registers. #58

Closed haruhi55 closed 4 months ago

haruhi55 commented 4 months ago

Currently, we have implemented two data feeding paths:

  1. Transfer 2D tiles from global memory to shared memory using CuTe's tiled_copy.
  2. Transfer 2D tiles from shared memory to the register file by wrapping ldmatrix.

When using CuTe's tiled_copy, users specify the thread organization and the number of elements each thread accesses along contiguous memory. To optimize I/O efficiency, the implementation leverages the maximum width for vectorized access, which is 128 bits. For half-precision data types, this corresponds to 8 elements. Consequently, while this interface can enhance performance, it imposes specific shape requirements on the input data.
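
The arithmetic behind the 8-element figure can be sketched on the host. This is a minimal illustration, not CuTe's or TiledCUDA's code: the helper names `elems_per_access` and `min_contiguous_extent` are hypothetical, and the only facts taken from the discussion are the 128-bit maximum access width and the 16-bit width of half precision.

```cpp
#include <cassert>
#include <cstddef>

// Widest vectorized access per thread, in bits (as stated above).
const std::size_t kMaxAccessBits = 128;

// Hypothetical helper: how many elements of width `elem_bits` fit in one
// 128-bit vectorized access.
inline std::size_t elems_per_access(std::size_t elem_bits) {
    assert(elem_bits > 0 && kMaxAccessBits % elem_bits == 0);
    return kMaxAccessBits / elem_bits;
}

// Hypothetical helper: smallest contiguous extent (in elements) covered when
// `threads_contiguous` threads sit side by side along the contiguous
// dimension and each issues one 128-bit access.
inline std::size_t min_contiguous_extent(std::size_t threads_contiguous,
                                         std::size_t elem_bits) {
    return threads_contiguous * elems_per_access(elem_bits);
}
```

For half precision, `elems_per_access(16)` gives 8, so the contiguous dimension of the input must be a multiple of 8 times the number of threads placed along it.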

Take GEMM as an example, where $C = A \times B$. Assume matrices $A$ and $C$ are laid out in row-major order, while matrix $B$ is laid out in column-major order. We use the minimal number of warps in a CTA, which is 1 warp consisting of 32 threads, organized as a $4 \times 8$ tile where 4 is the size of the contiguous dimension. Each thread loads 8 half-precision elements per access, so the 4 threads along the contiguous dimension cover $4 \times 8 = 32$ elements, while the 8 threads along the other dimension cover 8 rows (or columns). Under these assumptions, the following requirements apply:

| | $M$ | $N$ | $K$ |
|---|---|---|---|
| $A$ | $8$ | - | $4 \times 8$ (contiguous in memory) |
| $B$ | - | $8$ | $4 \times 8$ (contiguous in memory) |
| $C$ | $8$ | $4 \times 8$ (contiguous in memory) | - |
| wmma | 16 | 16 | 16 |
| Final | 16 | 32 | 32 |
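
One way to read the "Final" row is as the least common multiple of every constraint that applies to each dimension. The sketch below reproduces that reasoning on the host; `TileShape`, `min_tile_shape`, and the parameter names are hypothetical and are not part of TiledCUDA or CuTe.

```cpp
#include <cassert>

// Greatest common divisor via Euclid's algorithm.
inline int gcd_int(int a, int b) {
    while (b != 0) {
        int t = a % b;
        a = b;
        b = t;
    }
    return a;
}

// Least common multiple of two positive integers.
inline int lcm_int(int a, int b) {
    assert(a > 0 && b > 0);
    return a / gcd_int(a, b) * b;
}

// Hypothetical container for the minimal GEMM tile shape.
struct TileShape {
    int m, n, k;
};

// threads_strided:  threads along the non-contiguous dimension (8 above)
// threads_contig:   threads along the contiguous dimension (4 above)
// elems_per_thread: elements per vectorized access (8 for half precision)
// mma:              edge length of the wmma tile (16 above)
inline TileShape min_tile_shape(int threads_strided, int threads_contig,
                                int elems_per_thread, int mma) {
    const int contig = threads_contig * elems_per_thread;
    TileShape s;
    // M is the strided dimension of both A and C.
    s.m = lcm_int(threads_strided, mma);
    // N is the strided dimension of B and the contiguous dimension of C.
    s.n = lcm_int(lcm_int(threads_strided, contig), mma);
    // K is the contiguous dimension of both A and B.
    s.k = lcm_int(contig, mma);
    return s;
}
```

With the values from the table, `min_tile_shape(8, 4, 8, 16)` yields $M = 16$, $N = 32$, $K = 32$.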

CuTe's efficient loading therefore imposes constraints on tile shapes beyond those of wmma itself. To satisfy both, the $N$ and $K$ tile dimensions must be multiples of 32, and $M$ must be a multiple of 16.

A direct global-to-register load is useful for two reasons:

  1. Add an extra data feed path for performance scheduling.
  2. The tensor core's data distribution across threads is unusual. With a direct global-to-register path, we can conveniently generate test data on the host and bypass CuTe's shape constraints, so unit tests can cover more cases and check correctness.
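
To illustrate the second point, here is a rough host-side sketch of how test data could be arranged per thread. It assumes the non-transposed fragment layout that the PTX documentation describes for a single $8 \times 8$ matrix of 16-bit elements loaded by `ldmatrix`: lane $t$ holds the two adjacent elements at row $t / 4$, columns $2(t \bmod 4)$ and $2(t \bmod 4) + 1$. The names `LaneFragment`, `distribute_8x8`, and `iota_tile` are hypothetical, and `uint16_t` stands in for the raw bits of a half-precision value; none of this is TiledCUDA's actual implementation.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Two 16-bit elements held by one lane (one 32-bit register).
struct LaneFragment {
    std::uint16_t lo;  // element at the even column
    std::uint16_t hi;  // element at the following odd column
};

// Build a row-major 8x8 tile whose element values equal their linear index,
// which makes the per-lane ownership easy to inspect in a unit test.
inline std::array<std::uint16_t, 64> iota_tile() {
    std::array<std::uint16_t, 64> tile{};
    for (std::size_t i = 0; i < tile.size(); ++i) {
        tile[i] = static_cast<std::uint16_t>(i);
    }
    return tile;
}

// Distribute a row-major 8x8 tile over 32 lanes, following the assumed
// non-transposed ldmatrix layout for one 8x8 matrix of 16-bit elements.
inline std::array<LaneFragment, 32> distribute_8x8(
    const std::array<std::uint16_t, 64>& tile) {
    std::array<LaneFragment, 32> frags{};
    for (std::size_t lane = 0; lane < frags.size(); ++lane) {
        const std::size_t row = lane / 4;        // 4 lanes share one row
        const std::size_t col = 2 * (lane % 4);  // each lane owns 2 columns
        frags[lane].lo = tile[row * 8 + col];
        frags[lane].hi = tile[row * 8 + col + 1];
    }
    return frags;
}
```

A unit test could compare the registers produced by a device-side load against `distribute_8x8(iota_tile())`, without the tile having to meet the shape requirements of the vectorized global-to-shared path.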

I suspect there might be situations where using a locally sub-optimal tile shape could lead to benefits within a larger program context.