NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] LDSM Copy for int8 #1611

Open carlguo866 opened 4 weeks ago

carlguo866 commented 4 weeks ago

What is your question? I saw that SM75_U32x2_LDSM_N and such support LDSM for 16-bit types, but do similar instructions support int8? If not, is there a more performant CopyAtom than DefaultCopy?

thakkarV commented 4 weeks ago

You can use LDSM for any data type, not just 16b, as long as the data layouts are compatible.

carlguo866 commented 4 weeks ago

I looked through the source code for SM75_U32x2_LDSM_N, which has the following:

asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(dst0), "=r"(dst1) : "r"(smem_int_ptr));

so as long as I fit 2 int8s into 16 bits, it should work?

thakkarV commented 4 weeks ago

Yes
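
For concreteness, a hedged sketch of what that looks like in code (the alias name is ours, not from the thread): the Copy_Atom is simply declared over int8_t, and CuTe recasts the data to the instruction's 128-bit source registers, so each .b16 matrix element carries two int8 values.

#include <cute/tensor.hpp>

// Sketch: an LDSM copy atom over int8. Each ldmatrix .b16 element holds
// two int8s; no separate 8-bit instruction is needed.
using LdsmAtomInt8 = cute::Copy_Atom<cute::SM75_U32x2_LDSM_N, int8_t>;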

carlguo866 commented 3 weeks ago

I have investigated this issue much more closely since, and I'm still having trouble. Specifically, I'm using Copy_Atom<SM75_U32x2_LDSM_N, int8_t> and a TiledMMA with MMA_Atom<SM80_16x8x16_S32S8S8S32_TN>. When using these two to create a TiledCopy using make_tiled_copy_A, it generates a Tensor of layout:

(((_2,(_2,_2)),(_2,_2)),_8,(_2,_2)):(((_1,(_4,8)),(_2,_512)),_1024,(16,32))

Inspecting this layout, dim<0> doesn't seem to be mapping to contiguous memory when it should be. Because of this, when we convert the 8b tensor into a 128b tensor so it can be fed into LDSM, it produces the following odd layout:

((_2,(_2,_2))):((_1,(_4,8))) -> ((_1,(_1,_2))):((_1,(_1,0)))

This fails because LDSM only takes a size-1 tensor. I'm wondering if you have any insights as to how to correctly use an LDSM copy atom with an int8 TiledMMA. Thanks!
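
(A hedged sketch of the step being described, with a helper name that is ours: CuTe's copy_unpack recasts the int8 source fragment to the atom's 128-bit register type before issuing ldmatrix, so the recast only collapses to a size-1 tensor when each thread's elements are contiguous.)

#include <cute/tensor.hpp>

// Sketch: the recast that happens on the way into the LDSM atom; it only
// collapses to a size-1 uint128_t tensor if the int8 elements each thread
// owns are contiguous in shared memory.
template <class STensor>
CUTE_HOST_DEVICE auto as_ldsm_src(STensor const& s) {
  return cute::recast<cute::uint128_t>(s);
}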

ccecka commented 3 weeks ago

Can you provide print_latex(tiled_ldsm) and print_latex(your_smem_layout_before_partitioning)?

There is an incompatibility between the data layout and the instruction you are attempting to apply. You are noticing that the instruction is partitioning the data such that it is not accessing contiguous elements.
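
For readers following along, a minimal host-side sketch of how such diagrams are produced (print_latex writes LaTeX source to stdout, which can then be compiled; the swizzled row-major atom shown is the one that appears later in this thread):

#include <cute/tensor.hpp>
using namespace cute;

int main() {
  auto smem = composition(Swizzle<3,3,3>{},
                          Layout<Shape<_16,_64>, Stride<_64,_1>>{});
  print_latex(smem);  // ./a.out > smem.tex && pdflatex smem.tex
}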

carlguo866 commented 3 weeks ago

Here are the two latex PDFs:

Tiled LDSM: tiled_ldsm.pdf
Smem before partitioning (truncated due to compile time, should be 64 x 128): smem_before.pdf

It's unclear to me how I should read the your_smem_layout_before_partitioning graph; does the color mean anything? From the tiled LDSM it does seem like the partition is not contiguous, but I'm not sure how to make it contiguous.

ccecka commented 3 weeks ago

You change the shared memory layout so that it is compatible with the LDSM partitioning.

I can't reproduce your TiledCopy, it seems.

  auto mma = make_tiled_mma(MMA_Atom<SM80_16x8x16_S32S8S8S32_TN>{},
                            Layout<Shape<_4,_1>>{});
  auto copy = make_tiled_copy_A(Copy_Atom<SM75_U32x2_LDSM_N, int8_t>{}, mma);
  print_latex(copy);

produces the attached image, which is different from what you posted and should be fine for your (swizzled) row-major smem layout.

carlguo866 commented 3 weeks ago

I fixed a bug in my code and was able to get the layout above. Thanks so much for the help!

However, the layout of the 128-bit src tensor converted from the 8-bit tensor is still odd.

For the 8b tensor, the layout is (((_8,_2),_1)):(((_1,8),_0)), while the 128b tensor has layout (((_1,_2),_1)):(((_1,0),_0)). The extra dimension pointing at the same address is causing the LDSM to fail, and I'm not sure how to resolve it.

I noticed that the stride of mode <0,1>, the 8, is not static - not sure if this is part of the issue.

ccecka commented 3 weeks ago

You have a dynamic stride, which is preventing the proper transformation. You need to post your original pre-partitioned smem layout and how you're constructing it.
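
One hedged way to surface this failure mode early (the helper is ours, not from the thread): assert that the partitioned tensor's layout is fully static before handing it to the copy atom.

#include <cute/tensor.hpp>

// Sketch: trips at compile time if a partitioned layout carries dynamic
// strides, which is the symptom described above.
template <class Tensor>
CUTE_HOST_DEVICE void assert_static_layout(Tensor const& t) {
  static_assert(cute::is_static<decltype(t.layout())>::value,
                "partitioned smem layout has dynamic strides");
}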

carlguo866 commented 3 weeks ago

Minimal code looks like this, adapted from Flash Attention.

static constexpr int kBlockM = 128; 
static constexpr int kHeadDim = 64; 
static constexpr int kBlockKSmem = 64; 
static constexpr int kSwizzle = 3;
using SmemLayoutAtomQ = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
      Layout<Shape<_16, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(
      tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockM>, Int<kHeadDim>>{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<int8_t*>(smem_)),
      SmemLayoutQ{});

sQ seems to have an all-static layout, which is S<3,3,3> o _0 o (_128,_64):(_64,_1). The previous 8b tensor layout then comes from tSsQ = copy.partition_S(sQ);

ccecka commented 3 weeks ago

First, the swizzle has an MBase of 3, so only 8 elements are contiguous instead of the 16 that you intend. Using a swizzle of Swizzle<3,3,4> would make your code functional at least. (I suppose we could/should guard against intra-atom dynamic layouts within the copy_unpack/mma_unpack as they cannot be proven to be safe... Though the existing guards did catch this as a failed recast, correct?)

Second, a K-major + swizzle layout will almost always cause the partition to be dynamic, which can cost some performance. I recommend using a padded interleaved layout for more efficiency instead:

Tensor sQ = make_tensor(ptr,
                        Shape <_128,Shape <_16,            _4>>{},
                        Stride< _16,Stride< _1,Int<(128+1)*16>>>{});   // Might want +2 here, actually, for perfect banking
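
Reading the suggested strides (our interpretation, not stated in the thread): the inner (_16):(_1) mode keeps 16 int8s contiguous per row, a full 128-row panel of those 16-byte rows occupies 128*16 = 2048 bytes, and the (128+1)*16 = 2064-byte panel stride pads each panel by one extra 16-byte row so consecutive panels do not start at the same shared memory bank. The suggested +2 would pad by 32 bytes, shifting panel starts by 8 four-byte banks instead of 4.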
carlguo866 commented 3 weeks ago

Swizzle<3,3,4> doesn't work out-of-the-box... it still produces the layout (((_1,_2),_1)):(((_1,0),_0)). The existing guards do catch this, though not as a failed recast: the copy atom checks that the tensor has size 1 before executing the assembly. I'll investigate the padded interleaved layout in the meantime. Thanks!

Update: realized that it should be Swizzle<3,4,3> as MBase is the second parameter.

ccecka commented 3 weeks ago

Whoops, Swizzle<3,4,3> to set the swizzle mbase to 4.
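
Pulling the corrected pieces together, a sketch of the end-to-end int8 LDSM path (not a verified kernel; the Layout<Shape<_4,_1>> MMA tiling implies 128 threads per block, and the buffer sizes follow the earlier snippets in this thread):

#include <cute/tensor.hpp>
using namespace cute;

// Sketch: Swizzle<3,4,3> keeps 16 int8s contiguous, matching the 128-bit
// rows that ldmatrix loads per thread. Launch with 128 threads per block.
__global__ void ldsm_int8_sketch() {
  __shared__ int8_t smem[128 * 64];

  auto sQ_layout = tile_to_shape(
      composition(Swizzle<3,4,3>{},
                  Layout<Shape<_16,_64>, Stride<_64,_1>>{}),
      Shape<_128,_64>{});
  Tensor sQ = make_tensor(make_smem_ptr(smem), sQ_layout);

  auto tiled_mma  = make_tiled_mma(MMA_Atom<SM80_16x8x16_S32S8S8S32_TN>{},
                                   Layout<Shape<_4,_1>>{});
  auto tiled_copy = make_tiled_copy_A(Copy_Atom<SM75_U32x2_LDSM_N, int8_t>{},
                                      tiled_mma);

  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
  auto thr_mma  = tiled_mma.get_thread_slice(threadIdx.x);

  Tensor tSsQ = thr_copy.partition_S(sQ);           // LDSM source in smem
  Tensor tSrQ = thr_mma.partition_fragment_A(sQ);   // MMA register fragment
  Tensor tSrQ_view = thr_copy.retile_D(tSrQ);       // copy-atom view of it

  copy(tiled_copy, tSsQ, tSrQ_view);                // issues ldmatrix.x2
}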

carlguo866 commented 3 weeks ago

I've taken a closer look, and Swizzle<3,4,3> doesn't seem to create dynamic strides. But when I try to use the transposed LDSM SM75_U16x4_LDSM_T, the transposed smem layout does have a dynamic stride. Is there a way to resolve this easily, or should I just transition to padded interleaved layouts altogether? Here's how I set up the transposed smem:

using SmemLayoutKV = decltype(
      tile_to_shape(SmemLayoutAtomQ{}, Shape<Int<kBlockN>, Int<kHeadDim>>{})); // same as Q, just N instead of M
Tensor sV = make_tensor(ptr, SmemLayoutKV{}); 

using SmemLayoutAtomVtransposed = decltype(composition(
      Swizzle<kSwizzle, 4, 3>{}, Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
          Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
      SmemLayoutAtomVtransposed{}, Shape<Int<kHeadDim>, Int<kBlockN>>{}));
Tensor sVt =
      make_tensor(sV.data(), SmemLayoutVtransposed{});

using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x4_LDSM_T, int8_t>;
auto smem_tiled_copy_V = make_tiled_copy_B(
      SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

tOsVt produces the following layout: (((_2,_8),_1),(_2,_2),_8):(((64,_1),_0),(16,32),_1024), which is preventing me from recasting it into a 128b tensor for the LDSM to work.

ccecka commented 3 weeks ago

SM75_U16x4_LDSM_T has completely different partitioning patterns and access widths than SM75_U32x2_LDSM_N, so nothing in our previous discussion above applies to it. You would go through the same process: print_latex the patterns and engineer a good SMEM layout for that pattern. There should be a few obvious ones (row-major vs col-major, etc), then a few ways to improve bank access patterns, and only then perhaps a clever swizzle layout.
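
A hedged sketch of that first diagnostic step for the transposed atom, reusing the MMA construction from earlier in the thread:

#include <cute/tensor.hpp>
using namespace cute;

int main() {
  auto tiled_mma = make_tiled_mma(MMA_Atom<SM80_16x8x16_S32S8S8S32_TN>{},
                                  Layout<Shape<_4,_1>>{});
  auto copy_T = make_tiled_copy_B(Copy_Atom<SM75_U16x4_LDSM_T, int8_t>{},
                                  tiled_mma);
  print_latex(copy_T);  // study this pattern, then design the smem layout
}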