Currently, we have implemented two data feeding paths:

1. Transfer 2D tiles from global memory to shared memory using CuTe's `tiled_copy`.
2. Transfer 2D tiles from shared memory to the register file by wrapping `ldmatrix` (a minimal wrapper sketch follows this list).
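As a rough illustration of the second path, the sketch below wraps the `ldmatrix` PTX instruction directly; the function name is ours and the code is a minimal sketch rather than our exact implementation. It assumes sm_75 or newer.

```cpp
#include <cstdint>

// Minimal sketch of an ldmatrix wrapper (illustrative, requires sm_75+).
// All 32 threads of the warp participate: each lane supplies the
// shared-memory address of one 8-element row of halves, and the .x4 form
// loads four 8x8 b16 matrices, leaving four 32-bit fragments per thread.
__device__ void ldmatrix_x4(uint32_t (&frag)[4], void const* smem_row_ptr) {
  uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_row_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
      : "r"(addr));
}
```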
When using CuTe's `tiled_copy`, users specify the thread organization and the number of elements each thread accesses along the contiguous dimension. To optimize I/O efficiency, the implementation leverages the maximum width for vectorized access, which is 128 bits; for half-precision data, this corresponds to 8 elements per access. Consequently, while this interface can enhance performance, it imposes specific shape requirements on the input data.
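As a sketch (the kernel name and tile extents are illustrative, chosen to match the running example below), the following builds a `tiled_copy` in which one warp, with 4 threads spanning the contiguous dimension and 8 half elements per thread, moves a row-major $8 \times 32$ tile from global to shared memory using one 128-bit access per thread:

```cpp
#include <cute/tensor.hpp>
using namespace cute;

// Sketch: one warp copies a row-major 8 x 32 half tile from global to
// shared memory. Threads form an 8 x 4 grid with thread ids running along
// the contiguous (second) dimension; each thread issues a single 128-bit
// access covering 8 half elements, so the warp covers 8 x (4*8) elements.
__global__ void copy_tile(half_t const* gptr) {
  Tensor gA = make_tensor(make_gmem_ptr(gptr),
                          make_shape(Int<8>{}, Int<32>{}),
                          make_stride(Int<32>{}, Int<1>{}));  // row-major

  __shared__ half_t smem[8 * 32];
  Tensor sA = make_tensor(make_smem_ptr(smem),
                          make_shape(Int<8>{}, Int<32>{}),
                          make_stride(Int<32>{}, Int<1>{}));

  auto tiled_copy = make_tiled_copy(
      Copy_Atom<UniversalCopy<uint128_t>, half_t>{},  // 128-bit vectorized access
      Layout<Shape<_8, _4>, Stride<_4, _1>>{},        // 32 threads, ids run along dim 1
      Layout<Shape<_1, _8>>{});                       // 1 x 8 half elements per thread

  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
  Tensor src = thr_copy.partition_S(gA);  // this thread's slice of the source
  Tensor dst = thr_copy.partition_D(sA);  // this thread's slice of the destination
  copy(tiled_copy, src, dst);
}
```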
Take GEMM as an example, where $C = A \times B$. Assume matrices $A$ and $C$ are laid out in row-major order, while matrix $B$ is laid out in column-major order, and that we use the minimal number of warps in a CTA: a single warp of 32 threads, organized as a $4 \times 8$ thread tile where 4 spans the contiguous dimension. Under these assumptions, the following requirements apply:
|       | $M$ | $N$ | $K$ |
| ----- | --- | --- | --- |
| $A$   | $8$ |     | $4 \times 8$ (contiguous in memory) |
| $B$   |     | $8$ | $4 \times 8$ (contiguous in memory) |
| $C$   | $8$ | $4 \times 8$ (contiguous in memory) |     |
| wmma  | 16  | 16  | 16  |
| Final | 16  | 32  | 32  |
CuTe's efficient loading therefore imposes constraints on tile shapes beyond those of the tensor core itself. When adhering to these global-load constraints together with the wmma shape, the $N$ and $K$ tile extents must be multiples of 32, and $M$ a multiple of 16.
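Concretely, each entry in the Final row is the per-dimension least common multiple of the load constraints and the wmma shape in the table above:

$$M = \mathrm{lcm}(8,\,16) = 16, \qquad N = \mathrm{lcm}(8,\,4 \times 8,\,16) = 32, \qquad K = \mathrm{lcm}(4 \times 8,\,16) = 32.$$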
A global-to-register load is useful for two reasons (a sketch follows this list):

1. It adds an extra data-feeding path for performance scheduling.
2. The tensor core's data distribution is unusual. With a direct load, we can conveniently generate test data on the host and bypass CuTe's implementation constraints, covering more cases in unit tests to ensure correctness.
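As an illustration of such a direct load, the sketch below has each lane fetch the $A$-operand fragment elements it owns straight from a row-major global matrix. A hedge is needed here: wmma fragments are officially opaque, so this sketch instead assumes the documented per-thread layout of the PTX `mma.sync` m16n8k16 instruction; the function name and parameters are ours.

```cpp
#include <cuda_fp16.h>

// Illustrative sketch: each lane loads the eight A-fragment halves it owns
// for one mma.sync m16n8k16 tile directly from global memory (A row-major,
// leading dimension lda). Per the PTX ISA, lane l owns rows {l/4, l/4 + 8}
// and the column pairs starting at (l%4)*2 and (l%4)*2 + 8.
__device__ void load_a_fragment_g2r(half const* A, int lda, half (&a)[8]) {
  int lane = threadIdx.x % 32;
  int row0 = lane / 4;        // row inside the upper 8x8 sub-tiles
  int row1 = row0 + 8;        // row inside the lower 8x8 sub-tiles
  int col0 = (lane % 4) * 2;  // first column pair owned by this lane
  a[0] = A[row0 * lda + col0];      a[1] = A[row0 * lda + col0 + 1];
  a[2] = A[row1 * lda + col0];      a[3] = A[row1 * lda + col0 + 1];
  a[4] = A[row0 * lda + col0 + 8];  a[5] = A[row0 * lda + col0 + 9];
  a[6] = A[row1 * lda + col0 + 8];  a[7] = A[row1 * lda + col0 + 9];
}
```

Because each lane's addresses follow the fragment ownership pattern rather than contiguous memory, such a load is typically less coalesced than a vectorized `tiled_copy`; its value is flexibility, not peak bandwidth.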
I suspect there are situations where a locally sub-optimal tile shape could yield benefits within a larger program context.