NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] SmemCopyAtom and MMA_Atom for fp32? #1842

Open vickyandpiggy opened 2 weeks ago

vickyandpiggy commented 2 weeks ago

What is your question? Hello, I am developing a full-precision attention backward kernel using CUTLASS, and I am stuck on how to use the ldmatrix and mma instructions for fp32.

My GEMM calculation is based on fp32 matrices, i.e. the datatypes of D/A/B/C are all fp32. But the structs provided in mma_sm80.hpp take half-precision or mixed-precision inputs, so I am confused about how to do this correctly in full precision. Here is my current setting for the MMA, smem, and gmem. Is there a way to use SM75_U32x4_LDSM_N and one of the mma instructions in my case?

  // MMA: SIMT FMA atom tiled over 16x8x1 threads (128 threads, one FMA each)
  using TiledMma = TiledMMA<MMA_Atom<UniversalFMA<float, float, float>>,
                            Layout<Shape<Int<16>, Int<8>, _1>>>;

  // Smem: 16x32 row-major fp32 tile, swizzled to reduce bank conflicts
  using SmemLayoutAtom = decltype(
    composition(Swizzle<3,3,3>{},
                Layout<Shape < _16,_32>,
                       Stride<_32, _1>>{}));
  using SmemCopyAtom = Copy_Atom<DefaultCopy, float>;  // plain ld.shared

  // Gmem: 128-bit cp.async; 16x8 threads, each copying 4 contiguous floats
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, float>{},
                    Layout<Shape <_16,_8>,
                           Stride< _8,_1>>{},
                    Layout<Shape < _1,_4>>{}));
vickyandpiggy commented 5 days ago

Could anyone help? Many thanks!

vickyandpiggy commented 5 days ago

It seems to me that the mma instructions do not support fp32 for multiplicands A/B, per https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types. So can I use ldmatrix alone to accelerate the copy from smem to registers? Or is there a better practice for full precision?

cloudhan commented 4 days ago

NVIDIA tensor cores do not natively support fp32 A/B inputs, so that is not possible.

Alternatives are:

  1. Use the .tf32 variants and accept the reduced precision (see the sketch below).
  2. If the reduced precision is not acceptable, try https://arxiv.org/pdf/2203.03341.
  3. Try AMD cards.

The ldmatrix instruction is also tightly coupled to the tensor core fragment layout; maybe you can use the .tf32 version if the layouts match. But it may not speed up your matrix loading: you still need to take care of L2-friendly prefetching and eliminate bank conflicts manually. It is all about how to use cute correctly and performantly.
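
For option 1, a minimal sketch of what the swap could look like in CuTe, assuming the SM80 TF32 atom SM80_16x8x8_F32TF32TF32F32_TN from cute/arch/mma_sm80.hpp and a 2x2x1 atom layout to keep your 128 threads; treat this as illustrative, not a drop-in:

  // Reduced-precision path (sketch): tf32 inputs, fp32 accumulation.
  // SM80_16x8x8_F32TF32TF32F32_TN computes D(f32) = A(tf32) * B(tf32) + C(f32).
  using TiledMma = TiledMMA<MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
                            Layout<Shape<_2, _2, _1>>>;  // 4 warps = 128 threads

  // With 32-bit fragments, the u32 ldmatrix atom lines up with the MMA layout.
  using SmemCopyAtom   = Copy_Atom<SM75_U32x4_LDSM_N, float>;
  using SmemTiledCopyA = decltype(make_tiled_copy_A(SmemCopyAtom{}, TiledMma{}));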

vickyandpiggy commented 4 days ago

> NVIDIA tensor cores do not natively support fp32 A/B inputs, so that is not possible. [...]

Thank you so much for the suggestions! I got it :)

thakkarV commented 4 days ago

> The ldmatrix instruction is also tightly coupled to the tensor core fragment layout

This is not necessarily true. You can in principle copy to arbitrary layouts using LDSM provided the partitioning is valid.

> NVIDIA tensor cores do not natively support fp32 A/B inputs

This is also irrelevant: @vickyandpiggy is trying to use the SIMT cores for the matmul itself. In that case, you can still totally use LDSM, provided the smem layout is legal to partition with LDSM.
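
A hedged sketch of that idea against the FMA-based TiledMma from the question: make_tiled_copy_A builds the smem-to-register copy from the MMA's A-operand partitioning, and whether LDSM can service the FMA fragments depends on the value layout the tiled MMA produces.

  // Attempt to feed the SIMT FMA fragments with ldmatrix (sketch).
  // If the tiled MMA's register layout is not reachable by LDSM,
  // this composition should fail at compile time rather than miscompute.
  using SmemCopyAtom   = Copy_Atom<SM75_U32x4_LDSM_N, float>;
  using SmemTiledCopyA = decltype(make_tiled_copy_A(SmemCopyAtom{}, TiledMma{}));

  // Always-valid fallback: plain ld.shared through the universal copy atom.
  using SmemCopyAtomFallback = Copy_Atom<DefaultCopy, float>;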

thakkarV commented 4 days ago

@vickyandpiggy Please do not be discouraged.

> Is there a way to use SM75_U32x4_LDSM_N and one of the mma instructions in my case?

What have you tried? What does the kernel look like so far? By the way, for the SIMT cores the math throughput is low enough that it should not matter whether you use ld.shared or ldmatrix; you should still be able to achieve peak throughput.
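
For reference, the mainloop copy step looks the same whichever copy atom wins. A sketch, where sA, tCrA, and smem_tiled_copy_A stand in for the question's smem tile, MMA register fragment, and tiled copy:

  // Smem -> register copy step, independent of the chosen copy atom (sketch).
  auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(threadIdx.x);
  auto tCsA      = smem_thr_copy_A.partition_S(sA);   // (CPY, CPY_M, CPY_K) smem view
  auto tCrA_view = smem_thr_copy_A.retile_D(tCrA);    // MMA fragment, retiled for the copy
  for (int k = 0; k < size<2>(tCsA); ++k) {
    cute::copy(smem_tiled_copy_A, tCsA(_, _, k), tCrA_view(_, _, k));
  }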