wmmae / wmma_extension

An extension library of WMMA API (Tensor Core API)
https://arxiv.org/abs/2308.15152
MIT License

What about load_matrix? #3

Closed FdyCN closed 11 months ago

FdyCN commented 1 year ago

The work you've done is really helpful for me. Thank you.

And I have a question: does this extension support loading a smaller matrix (e.g. 2x16) into a 16x16 fragment, padding the remaining positions of the fragment with 0?

enp1s0 commented 1 year ago

Thank you, @FdyCN, for your question.

One possible approach to achieve this is by using the foreach_ij function, as demonstrated in the code snippet below:

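// Start from an all-zero fragment so that every element outside the small matrix stays 0.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
nvcuda::wmma::fill_fragment(frag_b, __float2half(0.f));

// 16x2 column-major source matrix (assumed to be filled before this point).
__shared__ half matrix[16 * 2];
mtk::wmma::foreach_ij<decltype(frag_b)>(
  [&](const unsigned* frag_index_list,
      const unsigned fragment_index_count,
      const unsigned i,
      const unsigned j) {
    // (i, j) is a position in the 16x16 matrix; frag_index_list holds the
    // fragment elements of this thread that map to that position.
    for (unsigned f = 0; f < fragment_index_count; f++) {
      if (j < 2) {  // only the first two columns exist in the source
        frag_b.x[frag_index_list[f]] = matrix[j * 16 + i];
      }
    }
  });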

Ideally, the load_matrix_sync function would take a template parameter that specifies the actual matrix size (e.g., 16x2) directly, which would reduce the number of operations. However, this library does not currently provide such a feature.
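
To check the result of the padded load, a host-side reference can help. The following is only a minimal sketch and is not part of wmma_extension: pad_to_tile and kTile are hypothetical names, and float stands in for half so that it compiles without CUDA headers. It builds the column-major 16x16 matrix that the fragment is expected to represent, using the same j * 16 + i indexing as the snippet above.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper for illustration only (not a wmma_extension API).
constexpr std::size_t kTile = 16;  // fragment tile is kTile x kTile

// `src` is a column-major rows x cols matrix with leading dimension `rows`.
// Returns a column-major kTile x kTile matrix with `src` in the top-left
// corner and zeros everywhere else.
std::array<float, kTile * kTile> pad_to_tile(const std::vector<float>& src,
                                             std::size_t rows,
                                             std::size_t cols) {
  assert(rows <= kTile && cols <= kTile);
  assert(src.size() == rows * cols);

  std::array<float, kTile * kTile> tile{};  // value-initialized to all zeros
  for (std::size_t j = 0; j < cols; ++j) {
    for (std::size_t i = 0; i < rows; ++i) {
      tile[j * kTile + i] = src[j * rows + i];
    }
  }
  return tile;
}
```

One possible use is to store the fragment contents back to memory on the device and compare them element by element against this reference.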

Let me know if you have any further questions.

Thanks.

FdyCN commented 1 year ago

@enp1s0 thank you for the reply. I will try what you said.