Closed Bruce-WangGF closed 4 months ago
// 计算每个线程的结果
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// 我们将点积循环放在外循环中,这有助于重用Bs,我们可以将其缓存在tmp变量中。
float Btmp = Bs[dotIdx * BN + threadCol];
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
threadResults[resIdx] +=
As[(threadRow * TM + resIdx) * BK + dotIdx] * Btmp;
}
}
dotIdx 递增, Bs[dotIdx * BN + threadCol] 的索引是按列移动的呀,这里没有问题
你可以打印一下具体的索引值看一下
As 在上面的循环里面也不是按照列移动的,后面是 +dotIdx 按照行移动
我是这么理解的,在例子中,As是64行8列,Bs是8行64列。通过bk_idx,也就是外层循环逐步往前推,计算C小块的值。 float Btmp = Bs[dotIdx BN + threadCol]; 这句代码,在例子中,threadCol 是64个线程并行控制,for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) 这个for循环相当于串行控制,如果Bs是8行64列,那dotIdx BN不就是每次跳过一行吗?threadCol为64个线程控制每个元素,是我哪里理解错了吗
我是这么理解的,在例子中,As是64行8列,Bs是8行64列。通过bk_idx,也就是外层循环逐步往前推,计算C小块的值。 float Btmp = Bs[dotIdx BN + threadCol]; 这句代码,在例子中,threadCol 是64个线程并行控制,for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) 这个for循环相当于串行控制,如果Bs是8行64列,那dotIdx BN不就是每次跳过一行吗?threadCol为64个线程控制每个元素,是我哪里理解错了吗
每次跳过一行不就是按列取吗,和图对的上吧,看下图上的箭头
想明白了,谢谢😂
你好,我看了你的Thread Tiling那块的代码,有些地方不太理解。
// 计算每个线程的结果 for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { // 我们将点积循环放在外循环中,这有助于重用Bs,我们可以将其缓存在tmp变量中。 float Btmp = Bs[dotIdx * BN + threadCol]; for (uint resIdx = 0; resIdx < TM; ++resIdx) { threadResults[resIdx] += As[(threadRow * TM + resIdx) * BK + dotIdx] * Btmp; } }
这段代码我不太能理解,尤其是Btmp在Bs中的索引是*BN的,也就是按行索引的,threadCol也是[0-63]计算整行,As是乘TM,按列索引,这是否和图示的As按行Bs按列不太符合,还是我哪里理解的不对?