Open manupak opened 4 days ago
This is how the generated markdown looks now:
A layout representing a mapping from GPU thread hierarchy to a shape
Syntax:
#iree_vector_ext.nested_layout<
::llvm::ArrayRef<int64_t>, # subgroupTile
::llvm::ArrayRef<int64_t>, # batchTile
::llvm::ArrayRef<int64_t>, # outerTile
::llvm::ArrayRef<int64_t>, # threadTile
::llvm::ArrayRef<int64_t>, # elementTile
::llvm::ArrayRef<int64_t>, # subgroupStrides
::llvm::ArrayRef<int64_t> # threadStrides
>
This layout explicitly defines how the shape of the associated vector is mapped to a compute hierarchy. We consider the following levels of hierarchy, inspired by GPUs:
Note that the elements in a thread are also conceptually viewed as
three dimensions, i.e. elements per thread = batch x outer x element.
However, the final order of the sub-dimensions does not exactly follow that
hierarchy. For example, a one-dimensional vector, say vector<n x f16>,
is viewed as a 5-dimensional vector:
vector<subgroup x batch x outer x thread x element>
For a two-dimensional vector, each of the above sub-dimensions is
doubled, i.e. vector<n1 x n2 x f16>
is viewed as a
vector<subgroup1 x subgroup2 x batch1 x batch2 x ... x element1 x element2>
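To make the ordering concrete, here is a minimal Python sketch (not IREE code) that builds the conceptual view from the five tile lists. The names `conceptual_shape` and `undistributed_shape` are hypothetical helpers, and the tile values are borrowed from the vector<64x64xf16> example further down.

```python
from math import prod

# Tiling levels in the order they appear in the conceptual view.
LEVELS = ["subgroup_tile", "batch_tile", "outer_tile", "thread_tile", "element_tile"]

def conceptual_shape(layout):
    """Concatenate the per-level tiles: [s1, s2, b1, b2, ..., e1, e2]."""
    return [size for level in LEVELS for size in layout[level]]

def undistributed_shape(layout):
    """Each vector dimension is the product of its tile counts over all levels."""
    rank = len(layout["subgroup_tile"])
    return [prod(layout[level][d] for level in LEVELS) for d in range(rank)]

layout = {
    "subgroup_tile": [2, 1],
    "batch_tile": [2, 4],
    "outer_tile": [1, 1],
    "thread_tile": [16, 4],
    "element_tile": [1, 4],
}

print(conceptual_shape(layout))     # [2, 1, 2, 4, 1, 1, 16, 4, 1, 4]
print(undistributed_shape(layout))  # [64, 64]
```

The second function doubles as a sanity check: a layout only fits a vector if the per-dimension product matches the vector's shape.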
Now, the subgroup and thread indices of this vector do not refer
directly to the subgroup_id and thread_id in the GPU context. Let's define them
as virtual_subgroup_id and virtual_thread_id, which satisfy the following
definitions:
virtual_subgroup_id[i] = (subgroup_id / subgroup_stride[i]) % subgroup_tile_size[i]
virtual_thread_id[i] = (thread_id / thread_stride[i]) % thread_tile_size[i]
The inverse mapping would be:
subgroup_id = sum_i(subgroup_stride[i] * virtual_subgroup_id[i]) % mul_i(subgroup_tile_size[i])
thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
for i = [0 : rank(undistributed_vector)]
NOTE: a stride of zero means that dimension is not distributed at that level of the hierarchy.
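The two mappings can be sketched in Python as follows. This is an illustration of the formulas above, not the IREE implementation; the function names are hypothetical, and returning a virtual id of 0 for a zero stride is an assumption made here to avoid dividing by zero.

```python
from math import prod

def virtual_ids(hw_id, strides, tile_sizes):
    """virtual_id[i] = (hw_id / stride[i]) % tile_size[i] for each dimension i."""
    # Assumption: a zero stride (dimension not distributed) yields virtual id 0.
    return [0 if s == 0 else (hw_id // s) % t for s, t in zip(strides, tile_sizes)]

def hardware_id(vids, strides, tile_sizes):
    """hw_id = sum_i(stride[i] * virtual_id[i]) % mul_i(tile_size[i])."""
    return sum(s * v for s, v in zip(strides, vids)) % prod(tile_sizes)

thread_strides, thread_tile = [1, 16], [16, 4]
print(virtual_ids(16, thread_strides, thread_tile))      # [0, 1]
print(hardware_id([0, 1], thread_strides, thread_tile))  # 16
```

For these strides the two functions round-trip for every thread id in [0, 64). With a zero stride the forward mapping discards information, so the inverse is no longer unique.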
We now describe each level of tiling. Each level of tiling represents a count of tiles over the next level (rather than a list of tile sizes).
This level of tiling is also known as "subgroup/warp distribution". It represents how the vector is distributed into subgroups.
For example, distributing vector<4x2xf16> with
subgroup_tile=[4, 2], subgroup_stride=[1, 4]
arranges the subgroups in the order:
virtual_subgroups_ids:
[0][0] , [0][1] , [1][0], [1][1], [2][0], [2][1], [3][0], [3][1]
subgroups_ids:
0, 4, 1, 5, 2, 6, 3, 7
The total number of subgroups used (computed by multiplying each dim in subgroup_tile) should be a multiple of the number of subgroups in the hardware. If the total number of subgroups used exceeds the number of subgroups in the hardware, then the subgroup used (say x) is x mod num_subgroups:
num_subgroups = 4
0, 4, 1, 5, 2, 6, 3, 7
| mod 4
V
0, 0, 1, 1, 2, 2, 3, 3
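The ordering and the wrap-around above can be reproduced with a short Python sketch. `subgroup_order` is a hypothetical helper that walks the virtual subgroup ids in row-major order and applies the inverse mapping; it is not part of IREE.

```python
from itertools import product
from math import prod

def subgroup_order(subgroup_tile, subgroup_strides):
    """Subgroup id for each virtual subgroup id, visited in row-major order."""
    total = prod(subgroup_tile)
    virtual_ids = product(*(range(t) for t in subgroup_tile))
    return [sum(s * v for s, v in zip(subgroup_strides, vid)) % total
            for vid in virtual_ids]

ids = subgroup_order([4, 2], [1, 4])
print(ids)  # [0, 4, 1, 5, 2, 6, 3, 7]

num_subgroups = 4  # hardware subgroup count used in the example above
print([x % num_subgroups for x in ids])  # [0, 0, 1, 1, 2, 2, 3, 3]
```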
This level of tiling is also known as "thread distribution" within a subgroup. The logic is quite similar to subgroup distribution, using the tile sizes and the 'thread_strides'.
So after the vector is distributed per thread on a subgroup, it is viewed as [batch] x [outer] x [element], where each group of sub-dimensions has as many dimensions as the rank of the original undistributed vector.
The first level, batches, is a way to represent instruction unrolling. For example, an intrinsic that can only take a 4x4 shape at a time uses batches to unroll a 16x16 shape to the native intrinsic shape.
The second level, outers, is a way to represent thread layout duplication required by a particular intrinsic. For example, some AMDGPU matrix multiplication variants require threads to be distributed like:
E.g.: outer_tile=[2, 1], thread_tile=[2, 5]
The thread layout of shape 2x5 is duplicated 2 times to get a layout of shape 4x5:
outer = 0,0 :
[0 1 2 3 4]
[5 6 7 8 9]
outer = 1,0 :
[0 1 2 3 4]
[5 6 7 8 9]
outer_tile represents the number of outers in a batch.
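The duplication can be illustrated with a small Python sketch that prints which thread owns each position of the 4x5 layout. `thread_grid` is a hypothetical helper for rank-2 layouts; it assumes an element tile of 1x1 and thread_strides of [5, 1], which is what the row-major numbering 0..9 in the picture above implies.

```python
def thread_grid(outer_tile, thread_tile, thread_strides):
    """Thread id at each position of an (outer x thread) layout, rank 2 only."""
    rows = outer_tile[0] * thread_tile[0]
    cols = outer_tile[1] * thread_tile[1]
    grid = []
    for r in range(rows):
        row = []
        for c in range(cols):
            # Outer is the slower-varying level, so the thread coordinate
            # repeats every thread_tile positions along each dimension.
            vid = (r % thread_tile[0], c % thread_tile[1])
            row.append(sum(s * v for s, v in zip(thread_strides, vid)))
        grid.append(row)
    return grid

for row in thread_grid([2, 1], [2, 5], [5, 1]):
    print(row)
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]
```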
The final level of tiling represents the minimum shape of vector that is treated as an atom.
element_tile represents the native size of the vector.
Vector to be distributed: vector<64x64xf16>
NestedLayout : <
subgroup_tile = [2, 1],
batch_tile = [2, 4],
outer_tile = [1, 1],
thread_tile = [16, 4],
element_tile = [1, 4],
subgroup_strides = [1, 0],
thread_strides = [1, 16]
>
This is conceptually viewed as a: vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>
where the first group of sub-dimensions
represents the distribution into subgroups.
The subgroup_strides being [1, 0] means
each subgroup is going to get a vector
as follows:
subgroup0 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup1 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]
subgroup2 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup3 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]
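The leading indices in the listing above follow from the virtual_subgroup_id formula. A minimal Python sketch, with `subgroup_slice` as a hypothetical helper, prints the subgroup part of each slice; showing a zero-stride dimension as ":" is an assumption that mirrors the notation used in the listing.

```python
subgroup_tile = [2, 1]
subgroup_strides = [1, 0]

def subgroup_slice(subgroup_id):
    """Index into the subgroup sub-dimensions for one hardware subgroup."""
    index = []
    for tile, stride in zip(subgroup_tile, subgroup_strides):
        if stride == 0:
            index.append(":")  # not distributed along this dimension
        else:
            index.append(str((subgroup_id // stride) % tile))
    return index

for sg in range(4):
    print(f"subgroup{sg}: [{','.join(subgroup_slice(sg))},...]")
# subgroup0: [0,:,...]
# subgroup1: [1,:,...]
# subgroup2: [0,:,...]
# subgroup3: [1,:,...]
```

Subgroups 2 and 3 wrap around to the same slices as subgroups 0 and 1 because subgroup_tile only provides two tiles along the first dimension.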
Then each vector<[2x4]x[1x1]x[16x4]x[1x4]> is distributed to threads in a subgroup using thread_strides = [1, 16]
recall: thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
thread0 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,0,:,:]
thread1 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,1,0,:,:]
...
...
thread16 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,1,:,:]
Finally we are left with a distributed vector
of conceptual view : vector<[2x4]x[1x1]x[1x4]>
where the actual shape is : vector<2x16>.
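As a cross-check of the 2x16 result, the following Python sketch lists which rows and columns of the original vector<64x64xf16> a single thread ends up holding. It is an illustration only: `owned_indices` is a hypothetical helper, and it assumes that each dimension is linearized in the order subgroup x batch x outer x thread x element, as in the conceptual view.

```python
from itertools import product

layout = {
    "batch_tile": [2, 4],
    "outer_tile": [1, 1],
    "thread_tile": [16, 4],
    "element_tile": [1, 4],
}

def owned_indices(layout, dim, virtual_subgroup_id, virtual_thread_id):
    """Indices along `dim` of the undistributed vector held by one thread."""
    B = layout["batch_tile"][dim]
    O = layout["outer_tile"][dim]
    T = layout["thread_tile"][dim]
    E = layout["element_tile"][dim]
    s, t = virtual_subgroup_id[dim], virtual_thread_id[dim]
    # Subgroup and thread coordinates are fixed; batch, outer and element vary.
    return [(((s * B + b) * O + o) * T + t) * E + e
            for b, o, e in product(range(B), range(O), range(E))]

# thread0 of subgroup0: virtual ids are [0, 0] at both levels.
rows = owned_indices(layout, 0, [0, 0], [0, 0])
cols = owned_indices(layout, 1, [0, 0], [0, 0])
print(rows)                  # [0, 16]
print(cols[:8])              # [0, 1, 2, 3, 16, 17, 18, 19]
print(len(rows), len(cols))  # 2 16
```

Each thread holds 2 rows and 16 columns, which matches the distributed shape vector<2x16>.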
This commit updates the NestedLayoutAttr docs to reflect the attribute as it is today and adds a few examples to make it more understandable.