wmmae / wmma_extension

An extension library for the WMMA API (Tensor Core API)
https://arxiv.org/abs/2308.15152
MIT License

WMMA API Extension

This extension provides features, including mapping between memory and fragment elements and loading/storing vectors as fragments, without using extra shared memory.

[!IMPORTANT] Please specify an appropriate virtual architecture for the target GPU. For instance, a program compiled with -arch=sm_70 will not work correctly on Ampere GPUs.

Requirements

Supported architectures / fragment

Functions

Primitive functions

foreach

This function computes the mapping between memory elements and fragment elements.

nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            const auto m = mem_index % 16;
            const auto n = mem_index / 16;
            for (unsigned i = 0; i < fragment_index_count; i++)
                frag_b.x[frag_index_list[i]] = convert_to<half>(matrix[n * 16 + m]);
        });

foreach_ij

This function computes the mapping between the matrix element position (i, j) and fragment elements.

nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach_ij<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
            for (unsigned f = 0; f < fragment_index_count; f++)
                frag_b.x[frag_index_list[f]] = convert_to<half>(matrix[j * 16 + i]);
        });

foreach_v

For matrix A/B

This function computes the mapping between the elements of a given vector and fragment elements.

nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_b)>(
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            for (unsigned i = 0; i < fragment_index_count; i++)
                frag_b.x[frag_index_list[i]] = convert_to<half>(vector[mem_index]);
        });
// is equivalent to `load_vector`

For accumulator

nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
        [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
            for (unsigned i = 0; i < fragment_index_count; i++)
                vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[i]]);
        });
// is equivalent to `store_vector`

map

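This function returns the mapping between matrix element (i, j) and the fragment elements (tid, fid) that hold it, where tid is the thread's lane index within the warp and fid is the index into that thread's fragment array x.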

nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
unsigned tid_list[2];
unsigned fid_list[2];
unsigned list_size;
mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);
for (unsigned k = 0; k < list_size; k++) {
  if ((threadIdx.x & 0x1f) == tid_list[k]) {
    frag_b.x[fid_list[k]] = 3.0f;
  }
}

Functions for vectors

Sample

#include <mma.h>
#include <wmma_extension/wmma_extension.hpp>

__global__ void kernel() {
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

    __shared__ float vec16[16];

    mtk::wmma::load_vector(frag_a, vec16);
    mtk::wmma::load_vector(frag_b, vec16);

    nvcuda::wmma::fill_fragment(frag_c, 0.0f);
    nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

    mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
}

Other functions

make_identity_matrix / add_eye

load_matrix

fill_zero

Debugging functions

print_fragment

This function outputs the elements of a fragment.

Publication

@inproceedings{ootomo_wmmae_2023,
  author = {Ootomo, Hiroyuki and Yokota, Rio},
  title = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
  year = {2023},
  series = {HPC Asia '23}
}

LICENSE

MIT