Open liuqi123123 opened 3 months ago
@ccecka thanks!
Going to have to be more specific about what you have and what you want -- want to show actual indices rather than a and b placeholders.
In general, we manipulate the Layout first to order the elements how we want (into 2x2 chunks, for example), then we can think about swizzling within the chunks, between the chunks, or swizzling the chunks themselves.
Thanks for your reply! I will describe it in detail. To simplify the problem, assume smem has only 4 banks, each bank is only 4 bytes wide, and a warp has only 4 threads.
Now we have a tile in smem whose layout is (4,4):(4,1), i.e. shape (4,4) with stride (4,1):
// 0 1 2 3
// +----+----+----+----+
// 0 | 0 | 1 | 2 | 3 |
// +----+----+----+----+
// 1 | 4 | 5 | 6 | 7 |
// +----+----+----+----+
// 2 | 8 | 9 | 10 | 11 |
// +----+----+----+----+
// 3 | 12 | 13 | 14 | 15 |
// +----+----+----+----+
Then I ran into a situation where the 4 threads of the warp have to read data from smem like this:
phase 1:
// 0 1 2 3
// +----+----+----+----+
// 0 | T0 | T2 | x | x |
// +----+----+----+----+
// 1 | x | x | x | x |
// +----+----+----+----+
// 2 | T1 | T3 | x | x |
// +----+----+----+----+
// 3 | x | x | x | x |
// +----+----+----+----+
phase 2:
// 0 1 2 3
// +----+----+----+----+
// 0 | x | x | x | x |
// +----+----+----+----+
// 1 | T0 | T2 | x | x |
// +----+----+----+----+
// 2 | x | x | x | x |
// +----+----+----+----+
// 3 | T1 | T3 | x | x |
// +----+----+----+----+
phase 3:
// 0 1 2 3
// +----+----+----+----+
// 0 | x | x | T0 | T2 |
// +----+----+----+----+
// 1 | x | x | x | x |
// +----+----+----+----+
// 2 | x | x | T1 | T3 |
// +----+----+----+----+
// 3 | x | x | x | x |
// +----+----+----+----+
phase 4:
// 0 1 2 3
// +----+----+----+----+
// 0 | x | x | x | x |
// +----+----+----+----+
// 1 | x | x | T0 | T2 |
// +----+----+----+----+
// 2 | x | x | x | x |
// +----+----+----+----+
// 3 | x | x | T1 | T3 |
// +----+----+----+----+
Clearly, there is a bank conflict in every phase: T0 and T1 hit the same bank, and so do T2 and T3.
So I want to change the data arrangement in smem:
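The conflicts in the four phases above can be checked on the host. The following is a minimal sketch in plain C++ (no CuTe), assuming the toy model from this post plus one extra assumption: each element is 4 bytes, so the bank of an element is its offset modulo 4. The names `offset_of`, `bank_of`, `max_conflict_degree` and `phase_access` are made up for illustration.

```cpp
#include <array>
#include <cassert>

// Toy model from the post: 4 banks, one 4-byte element per bank (assumption).
constexpr int kBanks = 4;
constexpr int kCols  = 4;  // row-major 4x4 tile, layout (4,4):(4,1)

struct Coord { int row; int col; };

// Row-major offset of a coordinate, and the bank that offset falls into.
constexpr int offset_of(Coord c) { return c.row * kCols + c.col; }
constexpr int bank_of(int offset) { return offset % kBanks; }

// Coordinates read by T0..T3 in phase 1..4, copied from the diagrams above.
inline std::array<Coord, 4> phase_access(int phase) {
  assert(phase >= 1 && phase <= 4);
  int row0 = (phase == 1 || phase == 3) ? 0 : 1;  // row of T0/T2; T1/T3 use row0 + 2
  int col0 = (phase <= 2) ? 0 : 2;                // column of T0/T1; T2/T3 use col0 + 1
  return {{ {row0, col0}, {row0 + 2, col0}, {row0, col0 + 1}, {row0 + 2, col0 + 1} }};
}

// Largest number of threads hitting the same bank in one phase
// (1 = conflict-free, 2 = two-way conflict, ...). Hypothetical helper.
inline int max_conflict_degree(const std::array<Coord, 4>& access) {
  std::array<int, kBanks> hits{};
  int worst = 0;
  for (const Coord& c : access) {
    int b = bank_of(offset_of(c));
    ++hits[b];
    if (hits[b] > worst) worst = hits[b];
  }
  return worst;
}
```

Calling `max_conflict_degree(phase_access(p))` for `p` from 1 to 4 reports a two-way conflict in each phase under these assumptions.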
// 0 1 2 3
// +----+----+----+----+
// 0 | 0 | 1 | 2 | 3 |
// +----+----+----+----+
// 1 | 4 | 5 | 6 | 7 |
// +----+----+----+----+
// 2 | 8 | 9 | 10 | 11 |
// +----+----+----+----+
// 3 | 12 | 13 | 14 | 15 |
// +----+----+----+----+
----------->>>
// 0 1 2 3
// +----+----+----+----+
// 0 | 0 | 1 | 2 | 3 |
// +----+----+----+----+
// 1 | 4 | 5 | 6 | 7 |
// +----+----+----+----+
// 2 | 10 | 11 | 8 | 9 |
// +----+----+----+----+
// 3 | 14 | 15 | 12 | 13 |
// +----+----+----+----+
(It is just like a swizzle in units of 2x2 elements.)
Then phases 1 to 4 change to:
phase 1:
// 0 1 2 3
// +----+----+----+----+
// 0 | T0 | T2 | x | x |
// +----+----+----+----+
// 1 | x | x | x | x |
// +----+----+----+----+
// 2 | x | x | T1 | T3 |
// +----+----+----+----+
// 3 | x | x | x | x |
// +----+----+----+----+
phase 2:
// 0 1 2 3
// +----+----+----+----+
// 0 | x | x | x | x |
// +----+----+----+----+
// 1 | T0 | T2 | x | x |
// +----+----+----+----+
// 2 | x | x | x | x |
// +----+----+----+----+
// 3 | x | x | T1 | T3 |
// +----+----+----+----+
phase 3:
// 0 1 2 3
// +----+----+----+----+
// 0 | x | x | T0 | T2 |
// +----+----+----+----+
// 1 | x | x | x | x |
// +----+----+----+----+
// 2 | T1 | T3 | x | x |
// +----+----+----+----+
// 3 | x | x | x | x |
// +----+----+----+----+
phase 4:
// 0 1 2 3
// +----+----+----+----+
// 0 | x | x | x | x |
// +----+----+----+----+
// 1 | x | x | T0 | T2 |
// +----+----+----+----+
// 2 | x | x | x | x |
// +----+----+----+----+
// 3 | T1 | T3 | x | x |
// +----+----+----+----+
The bank conflicts disappear in every phase.
But how do I do this irregular swizzle?
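The rearrangement can be written as a rule: in rows 2 and 3 the left and right 2-element halves are exchanged, which is the same as flipping bit 1 of the column index. The sketch below is plain C++ under the same toy assumptions (4 banks, one 4-byte element per bank); `swapped_offset`, `banks_in_phase` and `conflict_free` are hypothetical helper names. It lists the banks hit by T0..T3 in each phase after the rearrangement.

```cpp
#include <array>
#include <cassert>

constexpr int kBanks = 4;

// Physical offset of logical element (row, col) in the rearranged tile:
// rows 2 and 3 have their two 2-element halves exchanged (col ^ 2).
constexpr int swapped_offset(int row, int col) {
  return row * 4 + ((row >= 2) ? (col ^ 2) : col);
}

// Banks hit by T0, T1, T2, T3 in phase 1..4, following the diagrams above:
// T0/T2 read row0, T1/T3 read row0 + 2; T0/T1 read col0, T2/T3 read col0 + 1.
inline std::array<int, 4> banks_in_phase(int phase) {
  assert(phase >= 1 && phase <= 4);
  int row0 = (phase == 1 || phase == 3) ? 0 : 1;
  int col0 = (phase <= 2) ? 0 : 2;
  return {{ swapped_offset(row0,     col0)     % kBanks,    // T0
            swapped_offset(row0 + 2, col0)     % kBanks,    // T1
            swapped_offset(row0,     col0 + 1) % kBanks,    // T2
            swapped_offset(row0 + 2, col0 + 1) % kBanks }}; // T3
}

// True when no two threads share a bank.
inline bool conflict_free(const std::array<int, 4>& banks) {
  bool seen[kBanks] = {};
  for (int b : banks) {
    if (seen[b]) return false;
    seen[b] = true;
  }
  return true;
}
```

Under these assumptions every phase touches four distinct banks, for example banks 0, 2, 1, 3 for T0..T3 in phase 1.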
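Just try:
#include <cute/tensor.hpp>
using namespace cute;

int main() {
  // Row-major 4x4 layout: shape (4,4), stride (4,1)
  auto a = Layout<Shape<_4, _4>, Stride<_4, _1>>{};
  print_layout(a);
  // Swizzle the offsets produced by a, then print the resulting mapping
  auto b = composition(Swizzle<1, 1, 2>{}, a);
  print_layout(b);
}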
(_4,_4):(_4,_1)
0 1 2 3
+----+----+----+----+
0 | 0 | 1 | 2 | 3 |
+----+----+----+----+
1 | 4 | 5 | 6 | 7 |
+----+----+----+----+
2 | 8 | 9 | 10 | 11 |
+----+----+----+----+
3 | 12 | 13 | 14 | 15 |
+----+----+----+----+
Sw<1,1,2> o _0 o (_4,_4):(_4,_1)
0 1 2 3
+----+----+----+----+
0 | 0 | 1 | 2 | 3 |
+----+----+----+----+
1 | 4 | 5 | 6 | 7 |
+----+----+----+----+
2 | 10 | 11 | 8 | 9 |
+----+----+----+----+
3 | 14 | 15 | 12 | 13 |
+----+----+----+----+
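To see why Swizzle<1, 1, 2> produces this table, it helps to look at the offset bits. The sketch below is a plain C++ re-implementation of how I understand the functor for non-negative shifts: it takes the BBits bits starting at bit MBase + SShift and XORs them into the BBits bits starting at bit MBase. This is my reading of cute/swizzle.hpp, not the library code, and `swizzle_offset` and `swizzled_tile_1_1_2` are made-up names.

```cpp
#include <array>
#include <cassert>

// Sketch of the swizzle arithmetic for SShift >= 0 (assumption: mirrors
// what CuTe's Swizzle<BBits, MBase, SShift> does to an integer offset).
template <int BBits, int MBase, int SShift>
constexpr int swizzle_offset(int offset) {
  static_assert(BBits >= 0 && MBase >= 0 && SShift >= 0,
                "this sketch only covers non-negative parameters");
  constexpr int bit_mask = (1 << BBits) - 1;
  constexpr int src_mask = bit_mask << (MBase + SShift);  // bits that are read
  return offset ^ ((offset & src_mask) >> SShift);        // XORed in at bit MBase
}

// Offsets 0..15 of the row-major 4x4 tile after applying <1, 1, 2>.
inline std::array<int, 16> swizzled_tile_1_1_2() {
  std::array<int, 16> out{};
  for (int i = 0; i < 16; ++i) {
    out[i] = swizzle_offset<1, 1, 2>(i);
    assert(out[i] >= 0 && out[i] < 16);  // the swizzle permutes the tile
  }
  return out;
}
```

With <1, 1, 2>, bit 3 of the offset (set for rows 2 and 3) is XORed into bit 1 (the upper column bit), so offsets 8 to 15 have their 2-element halves exchanged while rows 0 and 1 stay untouched. With the same sketch, `swizzle_offset<2, 0, 2>` XORs the row index into the column index, which is the per-element swizzle mentioned in the original question.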
Awesome! thanks very much!
@mammoth831, I'm curious how you came up with this magical swizzle idea. Is there any mathematical theory behind it? Or do you have any relevant reference materials that you can recommend? I would be very grateful.
If you treat it as an 8-bank SRAM and each swizzling chunk size is 2, then the swizzle pattern you want is a common case for the generic swizzle functor. And the best reference is the code itself.
Assume that there are 4x4 elements in shared memory. I can use composition(Swizzle<2, 0, 2>{}, ...) to swizzle each element successfully, but now I want to swizzle in units of 2x2 elements, just like:
I divide the 4x4 tile into 2x2 pieces (a, b, c, d), each piece being a unit of 2x2 elements. I want to use something like composition(Swizzle<1, 0, 1>{}, ...) to swizzle them, but the elements of each piece are not contiguous in smem. Is there any solution?