iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[DOCS] Update VectorExt::NestedLayoutAttr docs #19246

Open manupak opened 4 days ago

manupak commented 4 days ago

This commit updates the NestedLayoutAttr docs to reflect what the attribute is today and adds a few examples to make it easier to understand.

manupak commented 4 days ago

The generated markdown now looks like this:

NestedLayoutAttr

A layout representing a mapping from GPU thread hierarchy to a shape

Syntax:

#iree_vector_ext.nested_layout<
  ::llvm::ArrayRef<int64_t>,   # subgroupTile
  ::llvm::ArrayRef<int64_t>,   # batchTile
  ::llvm::ArrayRef<int64_t>,   # outerTile
  ::llvm::ArrayRef<int64_t>,   # threadTile
  ::llvm::ArrayRef<int64_t>,   # elementTile
  ::llvm::ArrayRef<int64_t>,   # subgroupStrides
  ::llvm::ArrayRef<int64_t>   # threadStrides
>

This layout explicitly defines how the shape of the associated vector is mapped to a compute hierarchy. We consider the following levels of hierarchy, inspired by GPUs:

  1. Subgroups per workgroup
  2. Threads per subgroup
  3. Elements per thread

Note that the elements in a thread are also conceptually viewed as three dimensions, i.e. elements per thread = batch x outer x element. However, the final order of the sub-dimensions does not follow that hierarchy exactly. For example, a one-dimensional vector, say vector<n x f16>, is viewed as a five-dimensional vector<subgroup x batch x outer x thread x element>. For a two-dimensional vector, each of the above sub-dimensions is doubled, i.e. vector<n1 x n2 x f16> is viewed as vector<subgroup1 x subgroup2 x batch1 x batch2 x ... x element1 x element2>.

Now, when the vector is indexed, the 'subgroup' and 'thread' indices do not refer directly to the subgroup_id and thread_id in the GPU context. Let us call them virtual_subgroup_id and virtual_thread_id; they are defined as follows:

virtual_subgroup_id[i] = (subgroup_id / subgroup_stride[i]) % subgroup_tile_size[i]
virtual_thread_id[i]   = (thread_id   / thread_stride[i]) % thread_tile_size[i]

The inverse mapping would be:

subgroup_id = sum_i(subgroup_stride[i] * virtual_subgroup_id[i]) % mul_i(subgroup_tile_size[i])
thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])
    for i = [0 : rank(undistributed_vector)]

NOTE: if stride is zero, it represents non-distribution of that dimension on that hierarchy.
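
The two mappings above can be sketched in a few lines of Python. This is only an illustration of the formulas, not IREE code: the function names `virtual_ids` and `hardware_id` are hypothetical, and returning 0 for a zero-stride dimension is an assumption made here to model "non-distribution".

```python
from math import prod

def virtual_ids(hw_id, strides, tile_sizes):
    """virtual_id[i] = (hw_id / stride[i]) % tile_size[i] for each dimension i.

    Assumption: a zero stride (dimension not distributed at this level)
    is modelled as virtual id 0.
    """
    return [0 if s == 0 else (hw_id // s) % t
            for s, t in zip(strides, tile_sizes)]

def hardware_id(vids, strides, tile_sizes):
    """hw_id = sum_i(stride[i] * virtual_id[i]) % mul_i(tile_size[i])."""
    return sum(s * v for s, v in zip(strides, vids)) % prod(tile_sizes)

# Thread-level values taken from the full example in this document.
thread_tile = [16, 4]
thread_strides = [1, 16]
print(virtual_ids(16, thread_strides, thread_tile))
print(hardware_id([0, 1], thread_strides, thread_tile))
```

With these values, every thread id in the range 0 to 63 round-trips through `virtual_ids` and back through `hardware_id`.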

We now describe each level of tiling. Each level of tiling represents a count of tiles over the next level (rather than a list of tile sizes).

Subgroups per Workgroup

This level of tiling is also known as "subgroup/warp distribution". It represents how the vector is distributed into subgroups.

For example, distributing vector<4x2xf16> with subgroup_tile=[4, 2] and subgroup_stride=[1, 4] arranges the subgroups in the following order:

virtual_subgroups_ids:
[0][0] , [0][1] , [1][0], [1][1], [2][0], [2][1], [3][0], [3][1]
subgroups_ids:
0, 4, 1, 5, 2, 6, 3, 7

The total number of subgroups used (computed by multiplying the dims of subgroup_tile) should be a multiple of the number of subgroups in the hardware. If the total number of subgroups used exceeds the number of subgroups in the hardware, then a computed subgroup id (say x) is mapped to x mod num_subgroups:

num_subgroups = 4

0, 4, 1, 5, 2, 6, 3, 7
| mod 4
V
0, 0, 1, 1, 2, 2, 3, 3
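
The ordering and the wrap-around above can be reproduced with a short Python sketch. It assumes the virtual subgroup ids are enumerated in row-major order, as in the listing above; the variable names are illustrative only.

```python
from itertools import product

subgroup_tile = [4, 2]
subgroup_strides = [1, 4]
num_subgroups = 4  # subgroups available in the hardware (example value)

# Enumerate virtual subgroup ids in row-major order: (0,0), (0,1), (1,0), ...
virtual_subgroup_ids = list(product(*(range(t) for t in subgroup_tile)))

# subgroup_id = sum_i(subgroup_stride[i] * virtual_subgroup_id[i])
subgroup_ids = [sum(s * v for s, v in zip(subgroup_strides, vid))
                for vid in virtual_subgroup_ids]

# Wrap onto the hardware when more subgroups are used than exist.
wrapped_ids = [x % num_subgroups for x in subgroup_ids]

print(subgroup_ids)  # [0, 4, 1, 5, 2, 6, 3, 7]
print(wrapped_ids)   # [0, 0, 1, 1, 2, 2, 3, 3]
```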

Threads per Subgroup

This level of tiling is also known as "thread distribution" within a subgroup. The logic is the same as for subgroup distribution, using the thread tile sizes and the 'thread_strides'.

Element distribution on a thread

After the vector is distributed to the threads of a subgroup, each thread's portion is viewed as [batch] x [outer] x [element], where each group of sub-dimensions has as many dimensions as the rank of the original undistributed vector.

The first level, batches, is a way to represent instruction unrolling. For example, an intrinsic that can only take a 4x4 shape at a time uses batches to unroll a 16x16 shape into the native intrinsic shape.

The second level, outers, is a way to represent the thread layout duplication required by a particular intrinsic. Some AMDGPU matrix multiplication variants require threads to be distributed in this way.

E.g.: with outer_tile=[2, 1] and thread_tile=[2, 5], the thread layout of shape 2x5 is duplicated 2 times to get a layout of shape 4x5:

outer = 0,0 :
[0 1 2 3 4]
[5 6 7 8 9]

outer = 1,0 :
[0 1 2 3 4]
[5 6 7 8 9]

outer_tile represents the number of outers in a batch.
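
The duplicated layout above can be generated with a small Python sketch. The helper `thread_layout` is hypothetical, and it assumes the threads are numbered row-major within one thread tile (equivalent to thread_strides=[5, 1]), which is what the picture above shows, and an element_tile of [1, 1].

```python
def thread_layout(outer_tile, thread_tile):
    """Return a 2-D grid of thread ids covering outer_tile x thread_tile.

    Assumption: threads are numbered row-major inside one thread tile.
    Position r along a dimension splits into (outer, thread) as
    (r // thread_tile[d], r % thread_tile[d]), so the thread id depends
    only on the thread part and repeats for every outer index.
    """
    rows = outer_tile[0] * thread_tile[0]
    cols = outer_tile[1] * thread_tile[1]
    grid = []
    for r in range(rows):
        row = []
        for c in range(cols):
            tr, tc = r % thread_tile[0], c % thread_tile[1]
            row.append(tr * thread_tile[1] + tc)
        grid.append(row)
    return grid

for row in thread_layout([2, 1], [2, 5]):
    print(row)
```

The result has four rows: the 2x5 block of thread ids 0 to 9 appears once for outer = 0,0 and once more for outer = 1,0.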

The final level of tiling, elements, represents the minimum shape of vector that is treated as an atom.

element_tile represents the native size of the vector.

A full example

Vector to be distributed: vector<64x64xf16>

NestedLayout : <
    subgroup_tile = [2, 1],
    batch_tile = [2, 4],
    outer_tile = [1, 1],
    thread_tile = [16, 4],
    element_tile = [1, 4],
    subgroup_strides = [1, 0],
    thread_strides = [1, 16]
>

This is conceptually viewed as a vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>, where the first group of sub-dimensions represents the distribution into subgroups. With subgroup_strides = [1, 0], each subgroup gets a vector as follows:

subgroup0 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup1 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]
subgroup2 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[0,:,:,:,:,:,:,:,:,:]
subgroup3 : vector<[2x4]x[1x1]x[16x4]x[1x4]>
from vector<[2x1]x[2x4]x[1x1]x[16x4]x[1x4]>[1,:,:,:,:,:,:,:,:,:]

Then each vector<[2x4]x[1x1]x[16x4]x[1x4]> is distributed to the threads in a subgroup using thread_strides = [1, 16].

recall: thread_id = sum_i(thread_stride[i] * virtual_thread_id[i]) % mul_i(thread_tile_size[i])

thread0 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,0,:,:]
thread1 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,1,0,:,:]
...
...
thread16 : vector<[2x4]x[1x1]x[1x4]>
from vector<[2x4]x[1x1]x[16x4]x[1x4]>[:,:,:,:,0,1,:,:]

Finally, each thread is left with a distributed vector whose conceptual view is vector<[2x4]x[1x1]x[1x4]> and whose actual shape is vector<2x16>.
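
As a sanity check on the shapes in this example, the following Python sketch multiplies the tile counts per dimension. The dictionary and the helper names `undistributed_shape` and `per_thread_shape` are illustrative only and not part of IREE.

```python
from math import prod

layout = {
    "subgroup_tile": [2, 1],
    "batch_tile": [2, 4],
    "outer_tile": [1, 1],
    "thread_tile": [16, 4],
    "element_tile": [1, 4],
}

def undistributed_shape(layout):
    """Per dimension, multiply the tile counts of all five levels."""
    levels = ["subgroup_tile", "batch_tile", "outer_tile",
              "thread_tile", "element_tile"]
    rank = len(layout["subgroup_tile"])
    return [prod(layout[level][d] for level in levels) for d in range(rank)]

def per_thread_shape(layout):
    """Per dimension, multiply only the levels kept by a single thread."""
    levels = ["batch_tile", "outer_tile", "element_tile"]
    rank = len(layout["batch_tile"])
    return [prod(layout[level][d] for level in levels) for d in range(rank)]

print(undistributed_shape(layout))  # [64, 64]
print(per_thread_shape(layout))     # [2, 16]
```

The first result matches the vector<64x64xf16> being distributed, and the second matches the per-thread vector<2x16>.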