openmlsys / openmlsys-zh

《Machine Learning Systems: Design and Implementation》- Chinese Version
https://openmlsys.github.io
3.93k stars 426 forks source link

8.4 章节中 gemm_use_tile.cu 的代码问题 #458

Open hiha3456 opened 1 year ago

hiha3456 commented 1 year ago

这一部分有关三个 Layout 的代码我一直没有看明白,在这篇知乎中看到一样的内容后,我发现这里的代码实现可能有 bug。 按照定义:

那么此处 gemm_use_tile.cu 第10行和第11行 中对于m 和 n 的定义就有问题了,应该如下:

unsigned m= threadIdx.x* LayoutTile::m/LayoutBlock::m+ LayoutTile::m* blockIdx.x;
unsigned n= threadIdx.y* LayoutTile::n/LayoutBlock::n+ LayoutTile::n* blockIdx.y

同样的, gemm_use_tile.cu 第19行和第20行 中,iterationA 和 iterationB 应该分别指的是每个 thread 有多少个 (4,4) 的 subMatrix,这里应该是 2*2 = 4 个,那么 gemm_use_tile.cu 第21行和第22行 intervalA 和 intervalB 的定义就有问题了,按照后续代码,intervalA 和 intervalB 指的分别应该是每个 subMatrix 有多大,也就是 (LayoutThread::m, LayoutThread::n)