Thank you, @FdyCN, for your question.
One possible approach is to use the `foreach_ij` function, as demonstrated in the code snippet below:
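```cpp
// Assumes this runs inside a __global__ kernel with <mma.h>, <cuda_fp16.h>
// and the wmma_extension header included.
// 16x16 col-major B fragment, zero-filled so unused columns act as padding.
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
nvcuda::wmma::fill_fragment(frag_b, __float2half(0.f));

// 16x2 col-major source matrix; assumed to be filled (and __syncthreads() called) beforehand.
__shared__ half matrix[16 * 2];

mtk::wmma::foreach_ij<decltype(frag_b)>(
    [&](const unsigned* frag_index_list,
        const unsigned fragment_index_count,
        const unsigned i,
        const unsigned j) {
      // Only columns 0 and 1 exist in the source; other columns keep the zero fill.
      if (j < 2) {
        for (unsigned f = 0; f < fragment_index_count; f++) {
          frag_b.x[frag_index_list[f]] = matrix[j * 16 + i];
        }
      }
    });
```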
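To sanity-check the index arithmetic, here is a minimal host-side sketch (no GPU or library needed) of the logical 16x16 column-major matrix that `frag_b` is meant to hold once the lambda has run: position (i, j) takes `matrix[j * 16 + i]` when `j < 2`, and every other position keeps the zero from `fill_fragment`. The function name `pad_16x2_to_16x16` is hypothetical, and `float` stands in for `half` so it compiles without CUDA headers.

```cpp
#include <array>

// Hypothetical host-side reference: expands a 16x2 column-major matrix into
// the 16x16 column-major matrix that the zero-filled frag_b should represent.
std::array<float, 16 * 16> pad_16x2_to_16x16(const std::array<float, 16 * 2>& matrix) {
  std::array<float, 16 * 16> padded{};  // value-initialized to 0, like fill_fragment(..., 0)
  for (unsigned j = 0; j < 16; j++) {    // column index
    for (unsigned i = 0; i < 16; i++) {  // row index
      if (j < 2) {
        // Same mapping as the lambda: column-major, leading dimension 16.
        padded[j * 16 + i] = matrix[j * 16 + i];
      }
    }
  }
  return padded;
}
```

Because both layouts are column-major with a leading dimension of 16, the two source columns end up in the first 32 entries of the padded matrix, and everything after them is zero.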
Ideally, `load_matrix_sync` would accept a template parameter specifying the actual matrix size (e.g., 16x2), so that only the elements that really exist are loaded and fewer operations are performed. However, this library does not currently provide such a feature.
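An interface of the kind described above might look roughly like the following host-side sketch. It is purely hypothetical: `load_padded` and its template parameters `ROWS`/`COLS` are not part of this library or of `nvcuda::wmma`. Because the real matrix size is a compile-time constant, the copy loops visit only the ROWS x COLS source elements instead of all 256 tile positions. `float` again stands in for `half`.

```cpp
#include <array>

// Hypothetical: load a ROWS x COLS column-major matrix (leading dimension ldm)
// into a 16x16 column-major tile, zero-padding everything outside it.
// Returns the number of element copies performed.
template <unsigned ROWS, unsigned COLS>
unsigned load_padded(std::array<float, 16 * 16>& tile, const float* src, unsigned ldm) {
  static_assert(ROWS <= 16 && COLS <= 16, "source must fit in a 16x16 tile");
  tile.fill(0.0f);  // padding, analogous to fill_fragment(frag, 0)
  unsigned copies = 0;
  for (unsigned j = 0; j < COLS; j++) {    // only the real columns
    for (unsigned i = 0; i < ROWS; i++) {  // only the real rows
      tile[j * 16 + i] = src[j * ldm + i];
      copies++;
    }
  }
  return copies;
}
```

For example, `load_padded<16, 2>(tile, src, 16)` performs 32 copies for the 16x2 case, and `load_padded<2, 16>(tile, src, 2)` would handle a 2x16 column-major source the same way.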
Let me know if you have any further questions.
Thanks.
@enp1s0 thank you for the reply. I will try what you said.
The work you've done is really helpful for me. Thank you.
And I have a question: does this extension support loading a smaller matrix (e.g., 2x16) into a 16x16 fragment, i.e., padding the remaining positions of the fragment with zeros?