carlguo866 opened 4 weeks ago
You can use LDSM for any data type, not just 16-bit, as long as the data layouts are compatible.
I looked through the source code for `SM75_U32x2_LDSM_N`, which has the following:

```cpp
asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
    : "=r"(dst0), "=r"(dst1)
    : "r"(smem_int_ptr));
```

So as long as I fit two int8s into 16 bits, it should work?
Yes
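The bit-movement idea can be illustrated on the host with plain `memcpy`; this is only a sketch of the concept (the `pack_two_int8`/`unpack_two_int8` helpers are hypothetical, not CUTLASS APIs), not the actual device path:

```cpp
#include <cstdint>
#include <cstring>

// Two adjacent int8 elements occupy exactly one 16-bit word. The .b16 in
// ldmatrix only promises to move 16 bits of raw data; it is untyped, which
// is why the instruction works for int8 as long as the layout cooperates.
uint16_t pack_two_int8(int8_t a, int8_t b) {
    int8_t pair[2] = { a, b };
    uint16_t word;
    std::memcpy(&word, pair, sizeof(word));  // bitwise move, like a b16 load
    return word;
}

void unpack_two_int8(uint16_t word, int8_t out[2]) {
    std::memcpy(out, &word, sizeof(word));   // bitwise move back
}
```

The round trip recovers both int8 values unchanged, which is all the copy atom needs from the data type.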
I have investigated this issue much more closely since, and I'm still having trouble.
Specifically, I'm using `Copy_Atom<SM75_U32x2_LDSM_N, int8_t>` and a `TiledMMA` with `MMA_Atom<SM80_16x8x16_S32S8S8S32_TN>`. When I use these two to create a `TiledCopy` via `make_tiled_copy_A`, it generates a `Tensor` with the layout:

```
(((_2,(_2,_2)),(_2,_2)),_8,(_2,_2)):(((_1,(_4,8)),(_2,_512)),_1024,(16,32))
```
Inspecting this layout, `dim<0>` doesn't seem to map to contiguous memory when it should. Because of this, when we convert the 8-bit tensor into a 128-bit tensor so it can be fed into LDSM, it produces the following odd layout:

```
((_2,(_2,_2))):((_1,(_4,8))) -> ((_1,(_1,_2))):((_1,(_1,0)))
```

This fails because LDSM only accepts a size-1 tensor. I'm wondering if you have any insights into how to correctly use the LDSM copy atom with an int8 `TiledMMA`. Thanks!
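For intuition on why the recast step fails: viewing 8-bit data as one 128-bit value only works when the 16 elements are physically adjacent. A minimal host-side predicate capturing that condition (a hypothetical helper for illustration, not a CuTe API):

```cpp
#include <cstddef>

// A 128-bit "element" is just 16 bytes that must be physically adjacent.
// An int8 view can be recast to a 128-bit view only if consecutive elements
// sit at consecutive byte addresses (unit stride) and the extent divides
// evenly into 16-byte groups. A strided or swizzle-broken view fails this,
// which is what the zero-stride modes in the layout above are signalling.
bool recastable_as_128b(std::size_t elem_stride_bytes, std::size_t num_elems) {
    return elem_stride_bytes == 1 && num_elems % 16 == 0;
}
```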
Can you provide the output of `print_latex(tiled_ldsm)` and `print_latex(your_smem_layout_before_partitioning)`?
There is an incompatibility between the data layout and the instruction you are attempting to apply. You are noticing that the instruction is partitioning the data such that it is not accessing contiguous elements.
Here are the two latex pdfs:
Tiled LDSM
tiled_ldsm.pdf
Smem Before (truncated due to compile time, should be 64 x 128):
smem_before.pdf
It's unclear to me how I should read the `your_smem_layout_before_partitioning` graph; does the color mean anything? From the tiled LDSM, it does seem like the partition is not contiguous, but I'm not sure how to make it contiguous.
You change the shared memory layout so that it is compatible with the LDSM partitioning.
I can't reproduce your `TiledCopy`, it seems.
```cpp
auto mma  = make_tiled_mma(MMA_Atom<SM80_16x8x16_S32S8S8S32_TN>{},
                           Layout<Shape<_4,_1>>{});
auto copy = make_tiled_copy_A(Copy_Atom<SM75_U32x2_LDSM_N, int8_t>{}, mma);
print_latex(copy);
```

produces a partitioning that is different from what you posted and should be fine for your (swizzled) row-major smem layout.
I fixed a bug in my code and was able to get the layout above. Thanks so much for the help!
However, the layout of the 128-bit src tensor converted from the 8-bit tensor is still odd. For the 8-bit tensor, the layout is `(((_8,_2),_1)):(((_1,8),_0))`, while the 128-bit tensor has layout `(((_1,_2),_1)):(((_1,0),_0))`. The extra dimension pointing at the same address is causing the LDSM to fail, and I'm not sure how to resolve this. I noticed that the stride of the second mode within mode 0, `8`, is not static; not sure if this is part of the issue.
You have a dynamic stride, which is preventing the proper transformation. You need to post your original pre-partitioned smem layout and how you're constructing it.
Minimal code looks like this, adapted from Flash Attention:

```cpp
static constexpr int kBlockM     = 128;
static constexpr int kHeadDim    = 64;
static constexpr int kBlockKSmem = 64;
static constexpr int kSwizzle    = 3;

using SmemLayoutAtomQ = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
    Layout<Shape<_16, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(
    tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<int8_t*>(smem_)),
                        SmemLayoutQ{});
```
`sQ` seems to have an all-static layout, which is `S<3,3,3> o _0 o (_128,_64):(_64,_1)`. The previous 8-bit tensor layout is then just `tSsQ = copy.partition_S(sQ);`.
First, the swizzle has an MBase of 3, so only 8 elements are contiguous instead of the 16 that you intend. Using a swizzle of `Swizzle<3,3,4>` would at least make your code functional. (I suppose we could/should guard against intra-atom dynamic layouts within the `copy_unpack`/`mma_unpack`, as they cannot be proven safe... though the existing guards did catch this as a failed `recast`, correct?)
Second, a K-major + swizzle layout will almost always cause the partition to be dynamic, which can cost some performance. I recommend using a padded interleaved layout for more efficiency instead:
```cpp
Tensor sQ = make_tensor(ptr,
                        Shape <_128, Shape <_16,                 _4>>{},
                        Stride< _16, Stride<  _1, Int<(128+1)*16>>>{});
// Might want +2 instead of +1 here, actually, for perfect banking
```
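For intuition, this is the offset function that layout computes for int8 elements, written as a host-side sketch: the column index `k` splits as `k = k1*16 + k0`, and the padding only shifts where each 16-wide chunk starts, never the contiguity inside a chunk.

```cpp
#include <cstddef>

// Offset function of the padded interleaved layout above, in int8 elements:
// shape (128, (16, 4)) with strides (16, (1, (128+pad)*16)). Each row holds
// 16 contiguous elements per chunk; the pad rows between chunks stagger the
// chunk base addresses for better banking without breaking contiguity.
std::size_t offset(std::size_t m, std::size_t k, std::size_t pad) {
    std::size_t k0 = k % 16;   // position within a 16-element chunk
    std::size_t k1 = k / 16;   // which of the 4 chunks
    return m * 16 + k0 * 1 + k1 * (128 + pad) * 16;
}
```

Every 16-element chunk of a row is contiguous, which is exactly what a 16-byte LDSM access of int8 needs, with no swizzle (and hence no dynamic stride) involved.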
`Swizzle<3,3,4>` doesn't work out of the box; it still produces the layout `(((_1,_2),_1)):(((_1,0),_0))`. The existing guards do catch this, though not as a failed `recast`: the copy atom checks that the tensor is of size 1 before executing the assembly. I'll investigate the padded interleaved layout in the meantime. Thanks!
Update: I realized it should be `Swizzle<3,4,3>`, as `MBase` is the second parameter.

Whoops, yes: `Swizzle<3,4,3>` to set the swizzle MBase to 4.
I've taken a closer look, and `Swizzle<3,4,3>` doesn't seem to create dynamic strides. However, when I try to use the transposed LDSM `SM75_U16x4_LDSM_T`, the transposed smem layout does have a dynamic stride. Is there a way to resolve this easily, or should I just transition to padded interleaved layouts altogether? Here's how I set up the transposed smem:
```cpp
using SmemLayoutKV = decltype(
    tile_to_shape(SmemLayoutAtomQ{},
                  Shape<Int<kBlockN>, Int<kHeadDim>>{}));  // same as Q, just N instead of M
Tensor sV = make_tensor(ptr, SmemLayoutKV{});

using SmemLayoutAtomVtransposed = decltype(composition(
    Swizzle<kSwizzle, 4, 3>{},
    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
           Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
    SmemLayoutAtomVtransposed{}, Shape<Int<kHeadDim>, Int<kBlockN>>{}));
Tensor sVt = make_tensor(sV.data(), SmemLayoutVtransposed{});

using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x4_LDSM_T, int8_t>;
auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V   = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
```
`tOsVt` produces the layout `(((_2,_8),_1),(_2,_2),_8):(((64,_1),_0),(16,32),_1024)`, whose dynamic strides prevent me from doing a `recast` into a 128-bit tensor for the LDSM to work.
`SM75_U16x4_LDSM_T` has completely different partitioning patterns and access widths than `SM75_U32x2_LDSM_N`, so nothing in our previous discussion applies to it. You would go through the same process: `print_latex` the patterns and engineer a good SMEM layout for that pattern. There should be a few obvious ones (row-major vs col-major, etc.), then a few ways to improve bank access patterns, and only then perhaps a clever swizzle layout.
What is your question?

I saw that `SM75_U32x2_LDSM_N` and similar atoms support LDSM for 16-bit types, but do such instructions support int8? If not, is there a more performant `CopyAtom` than `DefaultCopy`?