wmmae / wmma_extension

An extension library of WMMA API (Tensor Core API)
https://arxiv.org/abs/2308.15152
MIT License

Unexpected Performance Regression #4

Open elvircrn opened 10 months ago

elvircrn commented 10 months ago

Hi, I hope that this repo is still maintained or at least open for questions. :)

My use-case:

I have a code base that uses the C++ wmma template API. For specific reasons I need to perform m16n16k16 operations, and I wanted a quick way to switch them out for the newer mma instructions using your awesome library, in the following way:

#define USE_MMA
#ifdef USE_MMA
#include "wmma_extension/detail/common.hpp"
#include "wmma_extension/tcec/detail/simt/mma_simt.hpp"
#else
#include <mma.h>
#endif

// ..

#ifdef USE_MMA
  namespace w = mtk::wmma::mma_simt;
  namespace wmma = nvcuda::wmma;
#else
  namespace w = nvcuda::wmma;
  namespace wmma = nvcuda::wmma;
#endif

and then later on use the namespaces as follows:

w::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, VALUE_TYPE, wmma::row_major>
...
w::fill_fragment(frags[i][j], (VALUE_TYPE)0);
// etc.

Basically the USE_MMA flag switches between this library and the built-in nvcuda wmma C++ API.

I have two questions:

I can post more details if necessary, of course (in case I didn't horribly misuse/misunderstand the library internals).

Thanks, Elvir

elvircrn commented 10 months ago

Whoops, it seems this code path doesn't invoke mma instructions (please correct me if I'm wrong). I guess I could get away with just issuing multiple m16n8k16 operations, but I am still wondering if there is a more convenient way of doing this?

enp1s0 commented 10 months ago

Hi @elvircrn, thank you for using this library. As you mentioned, mtk::wmma::mma_simt does not invoke the MMA instruction. Instead, there are two ways to use the m16n8k16 instruction to compute m16n16k16.

  1. As you mentioned, invoke the m16n8k16 instruction twice in your code.
  2. Use the mtk::wmma::tcec API as shown in the code below. This approach is available only when the fragment types of matrices C and D are float.
    
    #include <wmma_extension/tcec/tcec.hpp>

    // Policy: use mma.m16n8k16 without the single-precision emulation
    using mma_policy = mtk::wmma::tcec::Policy<mtk::wmma::tcec::op_mma, mtk::wmma::tcec::without_ec, 16, 8, 16>;

    constexpr std::uint32_t N = 16;
    mtk::wmma::tcec::fragment<nvcuda::wmma::matrix_a, N, N, N, half, nvcuda::wmma::col_major, mma_policy> frag_a;
    mtk::wmma::tcec::fragment<nvcuda::wmma::matrix_b, N, N, N, half, nvcuda::wmma::col_major, mma_policy> frag_b;
    // Specify 'half' here, even though the actual accumulator data type is 'float'
    mtk::wmma::tcec::fragment<nvcuda::wmma::accumulator, N, N, N, half, void, mma_policy> frag_c, frag_d;

    // ...
    mtk::wmma::tcec::load_matrix_sync(frag_a, smem, N);
    // ...
    mtk::wmma::tcec::mma_sync(frag_d, frag_a, frag_b, frag_c);
    // ...

The `mtk::wmma::tcec` API is primarily for emulating single-precision matrix multiplication on Tensor Cores, but you can use it for your purpose by changing the `Policy` of the fragments to disable the single-precision emulation mode. In this example code, the `mtk::wmma::tcec::mma_sync` function invokes the `mma.m16n8k16` instruction twice internally.

Feel free to let me know if you have any questions.

Thanks