NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] Copy Accumulator to GMEM directly? #1920

Closed: osayamenja closed this issue 2 days ago.

osayamenja commented 2 days ago

What is your question?

Hello! How do you copy the accumulator registers to global memory directly?

For example, in ampere_conv_kernel.h, how would we copy accum directly to gC, skipping the copy to sC?

Thanks!

osayamenja commented 2 days ago

I have tried something like the following, but it fails to compile with "Copy_Traits: src failed to vectorize into registers. Layout is incompatible with this CopyOp."

// ... 
copy(gmem_tiled_copy_C, accum, tDgC);

thakkarV commented 2 days ago

please see https://github.com/NVIDIA/cutlass/issues/1905

osayamenja commented 2 days ago

@thakkarV Thanks for responding! My mistake for not giving enough information. My use case is actually different from that, as there is no vectorization.

Here is the changed tiled copy.

auto gmem_tiled_copy_C = cute::make_tiled_copy(
        cute::Copy_Atom<cute::UniversalCopy<float>, float>{},
        cute::Layout<cute::Shape<cute::_16, cute::_8>>{},
        cute::Layout<cute::Shape<cute::_1, cute::_1>>{}); // 1x1 values per thread; is this the problem?

Below are the layouts.

((_2,_2),_4,_4):((_1,_2),_4,_16) //accum
--------------------------------------------
TiledCopy // gmem_tiled_copy_C
  Tiler_MN:       (_16,_8)
  TiledLayout_TV: (_128,_1):(_1,_0)
Copy_Atom
  ThrID:        _1:_0
  ValLayoutSrc: (_1,_1):(_0,_1)
  ValLayoutDst: (_1,_1):(_0,_1)
  ValLayoutRef: (_1,_1):(_0,_1)
  ValueType:    32b
--------------------------------------------
gmem_ptr[32b](0x420001800) o ((_1,_1),_8,_8):((_0,_0),16,1024) // tDgC

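One plausible reading of these printouts: both tensors hold 64 values per thread, but accum's first mode is (_2,_2) (4 values per MMA fragment), while the tiled copy's per-thread partition of gC has a first mode of (_1,_1) (1 value per copy atom). The sketch below is a plain C++ model, not CuTe: it evaluates the two printed layouts using CuTe's colexicographic convention (leftmost coordinate varies fastest). The mode lists are hand-flattened from the output above, and `Mode`, `size_of`, `offset_of`, and `first_mode_size` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One flattened layout mode: extent and stride (a hand-flattened stand-in for a CuTe Layout).
struct Mode {
  std::size_t shape;
  std::size_t stride;
};

// Number of logical elements described by a list of modes.
std::size_t size_of(const std::vector<Mode>& modes) {
  std::size_t n = 1;
  for (const Mode& m : modes) n *= m.shape;
  return n;
}

// Map a linear logical index to a memory offset, leftmost mode varying fastest
// (CuTe's colexicographic order).
std::size_t offset_of(const std::vector<Mode>& modes, std::size_t idx) {
  assert(idx < size_of(modes));
  std::size_t off = 0;
  for (const Mode& m : modes) {
    off += (idx % m.shape) * m.stride;
    idx /= m.shape;
  }
  return off;
}

// ((_2,_2),_4,_4):((_1,_2),_4,_16): the accumulator fragment, flattened.
const std::vector<Mode> accum_layout = {{2, 1}, {2, 2}, {4, 4}, {4, 16}};

// ((_1,_1),_8,_8):((_0,_0),16,1024): this thread's slice of gC under the tiled copy, flattened.
const std::vector<Mode> tDgC_layout = {{1, 0}, {1, 0}, {8, 16}, {8, 1024}};

// Values in the nested first ("value") mode. Assumes, as in both layouts above,
// that this nested mode was flattened into the first two entries.
std::size_t first_mode_size(const std::vector<Mode>& modes) {
  assert(modes.size() >= 2);
  return modes[0].shape * modes[1].shape;
}
```

In this model both layouts have 64 elements, and accum is compact (its offsets are exactly 0 through 63 in order), yet the first modes differ in size (4 versus 1). That mismatch is consistent with the static check that rejects copying accum through gmem_tiled_copy_C, since the copy expects a source partitioned by the tiled copy itself rather than by the MMA.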
thakkarV commented 2 days ago

If you don't care about vectorization, just drop the tiled copy. Partition the gmem tensor with the tiled MMA, then call copy from the partitioned register (rmem) accumulator tensor to the partitioned gmem tensor.

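auto tCgC = thr_mma.partition_C(tiled_gmem_tensor_C);          // (MMA,MMA_M,MMA_N): this thread's view of gmem C
auto tCrC = thr_mma.partition_fragment_C(tiled_gmem_tensor_C); // register accumulators with the matching layout
// ... mainloop accumulates into tCrC ...
copy(tCrC, tCgC);                                              // element-wise rmem -> gmem store
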
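Why this works can be shown with a toy model: because tCrC and tCgC are both produced by the same thr_mma partitioning, the i-th register value and the i-th gmem element of a thread refer to the same (m, n) coordinate of the C tile, so an element-wise copy(tCrC, tCgC) lands every accumulator in the right place. The sketch below is plain C++, not CuTe, and uses a hypothetical 4-thread, 4x4 thread-value mapping (`tv_to_mn`) in place of a real SM80 MMA atom; all names are illustrative.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kThreads = 4;    // hypothetical thread count
constexpr std::size_t kVals = 4;       // values per thread
constexpr std::size_t kM = 4, kN = 4;  // hypothetical C tile extent
constexpr std::size_t kLd = kM;        // column-major leading dimension of the gmem tile

struct Coord {
  std::size_t m, n;
};

// Hypothetical (thread, value) -> (m, n) mapping standing in for the tiled MMA's C layout.
// It is a bijection from the 16 (thread, value) pairs onto the 4x4 tile.
constexpr Coord tv_to_mn(std::size_t t, std::size_t v) {
  return {t % 2 + 2 * (v % 2), t / 2 + 2 * (v / 2)};
}

// Stand-in for one thread's accumulator fragment (tCrC):
// value v holds the result for coordinate tv_to_mn(t, v).
std::array<float, kVals> make_accumulators(std::size_t t) {
  assert(t < kThreads);
  std::array<float, kVals> acc{};
  for (std::size_t v = 0; v < kVals; ++v) {
    Coord c = tv_to_mn(t, v);
    acc[v] = 10.0f * c.m + c.n;  // pretend MMA result
  }
  return acc;
}

// Model of partition_C + copy: the thread's gmem view uses the SAME mapping,
// so writing value v to the offset of tv_to_mn(t, v) is a plain element-wise copy.
void store_fragment(std::size_t t, const std::array<float, kVals>& acc, float* gmem) {
  for (std::size_t v = 0; v < kVals; ++v) {
    Coord c = tv_to_mn(t, v);
    gmem[c.m + c.n * kLd] = acc[v];
  }
}
```

The design point the model isolates is that no layout conversion is needed when source and destination share one partitioner; if the gmem side were partitioned by a different thread-value mapping (as with the tiled copy above), value v would no longer correspond to the same coordinate.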
osayamenja commented 2 days ago

@thakkarV Life saver, thanks a ton! It compiles now!

Honestly, I would rather use vectorization, but I am following gemm_tn from sgemm_sm80.cu, which uses a 1x1 value layout.

I know I can vectorize that layout by changing the copy op's type parameter from TA to uint128_t and the value layout to Layout<Shape<_1,_4>>. I will experiment and see what happens. Thanks again!
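Conceptually, a uint128_t copy op with a 4-value layout lets one thread move 4 contiguous floats in a single 16-byte transfer, which is only legal when those 4 values are contiguous and 16-byte aligned in both source and destination; CuTe checks this at compile time from the layouts. The plain C++ sketch below models that rule at run time instead. `U128` and `copy_floats` are hypothetical names, and the scalar fallback path plays the role of the 1x1 value layout.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical 128-bit payload standing in for a uint128_t copy unit.
struct alignas(16) U128 {
  std::uint32_t w[4];
};

// True if p is 16-byte aligned, as a single 128-bit access requires.
inline bool aligned16(const void* p) {
  return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
}

// Copy n floats that are contiguous in both src and dst.
// Uses 16-byte chunks (4 floats each) when both pointers are aligned and n is a
// multiple of 4; otherwise falls back to one float at a time.
// Returns the number of 128-bit transfers performed (0 on the scalar path).
std::size_t copy_floats(const float* src, float* dst, std::size_t n) {
  assert(n == 0 || (src != nullptr && dst != nullptr));
  if (n % 4 == 0 && aligned16(src) && aligned16(dst)) {
    for (std::size_t i = 0; i < n; i += 4) {
      U128 chunk;
      std::memcpy(&chunk, src + i, sizeof(U128));
      std::memcpy(dst + i, &chunk, sizeof(U128));
    }
    return n / 4;
  }
  for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
  return 0;
}
```

For example, copying 8 floats between two 16-byte-aligned buffers takes 2 wide transfers, while starting one float into either buffer breaks alignment and forces the scalar path, which is the run-time analogue of the "failed to vectorize" error.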

thakkarV commented 2 days ago

For the vectorization "how to", you can follow the other issue I linked (#1905).