triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Mistakes in `class DistributedEncoding`'s illustration #4309

Open Shoreshen opened 3 months ago

Shoreshen commented 3 months ago

In the definition of `class DistributedEncoding` in the file `include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td`, there is an illustrative comment:

  let description = [{
Distributed encodings have a layout function L that is entirely characterized
by a d-dimensional tensor T. Note that L doesn't need to have the same shape
(or even the same rank) as the tensor it is encoding.

The layout function \mathcal{L} of this layout is then defined, for an
index `i` \in Z^d, as follows:

\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d]

Intuitively, when the tensor dim size T.shape[d] is larger than the layout
dim size L.shape[d], on that particular dim, we distribute values from the
tensor to threads mapped in the layout in a "wrapped around" manner, with
each thread owning multiple values.

OTOH, when the tensor dim size T.shape[d] is smaller than the layout
dim size L.shape[d], on that particular dim, we distribute values from the
tensor to threads mapped in the layout in a "broadcasted" manner, with
each value owned by multiple threads.

For example, for a tensor/layout pair
T = [x  x  x  x  x  x  x  x]
    [x  x  x  x  x  x  x  x]
L = [0  1  2  3 ]
    [4  5  6  7 ]
    [8  9  10 11]
    [12 13 14 15]

Then the data of T would be distributed as follow between the 16 CUDA threads:
L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
         {4,12}, {5,13}, {6,14}, {7,15}, {4,12}, {5, 13}, {6, 14}, {7, 15} ]
  }];
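
To make the intended behavior concrete, here is a small Python sketch (my own code, not from the repo; the `owners` helper is hypothetical) of the wrap/broadcast rule described above. It reproduces the L(T) listing:

```python
# Sketch of the intended semantics (my own helper names, not Triton code):
# per dim d, wrap with a modulo when T.shape[d] >= L.shape[d], and
# broadcast to every matching index when T.shape[d] < L.shape[d].
T_shape = (2, 8)
L_shape = (4, 4)
L = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11],
     [12, 13, 14, 15]]

def owners(i0, i1):
    rows = ({i0 % L_shape[0]} if T_shape[0] >= L_shape[0]
            else {j for j in range(L_shape[0]) if j % T_shape[0] == i0})
    cols = ({i1 % L_shape[1]} if T_shape[1] >= L_shape[1]
            else {j for j in range(L_shape[1]) if j % T_shape[1] == i1})
    return sorted(L[r][c] for r in rows for c in cols)

print([owners(i0, i1) for i0 in range(T_shape[0]) for i1 in range(T_shape[1])])
# [[0, 8], [1, 9], [2, 10], [3, 11], [0, 8], [1, 9], [2, 10], [3, 11],
#  [4, 12], [5, 13], [6, 14], [7, 15], [4, 12], [5, 13], [6, 14], [7, 15]]
```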

Based on my understanding:

  1. If rank here refers to the number of dimensions of the tensor and of the CTA (i.e., of L?), then the formula requires that the rank of the tensor be smaller than or equal to the rank of the layout.
  2. Let's say we compute the distribution of T[0,5]. We have T.shape[0]=2, T.shape[1]=8, L.shape[0]=4, L.shape[1]=4. Then:
    1. For dimension 0:
      1. k_0=0: we have 0 + 0 * 2 = 0 < 4, accept
      2. k_0=1: we have 0 + 1 * 2 = 2 < 4, accept
      3. k_0=2: we have 0 + 2 * 2 = 4, which is not < 4, reject (and all further k_0)
    2. For dimension 1:
      1. k_1=0: we have 5 + 0 * 8 = 5 > 4, reject (and all further k_1)
    3. In summary, no thread inside this CTA would hold T[0,5], contradicting the L(T) listing above (see the sketch below).
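
This can be checked mechanically. A short script (my own, not from the repo) that applies the condition exactly as written confirms the empty owner set for dimension 1:

```python
# Apply the condition exactly as written in the description:
# accept k_d only when i_d + k_d * T.shape[d] < L.shape[d].
T_shape, L_shape = (2, 8), (4, 4)

def valid_ks(i, d):
    # k_d can never exceed max(L_shape), so this range is a safe bound
    return [k for k in range(max(L_shape)) if i + k * T_shape[d] < L_shape[d]]

print(valid_ks(0, 0))  # dim 0, i_0 = 0 -> [0, 1] (layout rows 0 and 2)
print(valid_ks(5, 1))  # dim 1, i_1 = 5 -> []   (no thread owns T[0,5])
```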
feiyuvl commented 2 months ago
\mathcal{L}(T)[i_d] = L[(i_d + k_d*T.shape[d]) % L.shape[d]] \forall k_d such as i_d + k_d*T.shape[d] < L.shape[d]

Yes, the formula is not quite accurate. Maybe it can be fixed by changing the quantifier to

\forall k_d such that k_d = 0 or i_d + k_d*T.shape[d] < L.shape[d]
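
Running the same check with this amended condition (again my own sketch, just to sanity-check the proposal) recovers the illustrated entry for T[0,5]:

```python
# Amended condition: always accept k_d = 0, then index L modulo its shape.
T_shape, L_shape = (2, 8), (4, 4)
L = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]

def layout_indices(i, d):
    return {(i + k * T_shape[d]) % L_shape[d]
            for k in range(max(L_shape))
            if k == 0 or i + k * T_shape[d] < L_shape[d]}

threads = sorted(L[r][c] for r in layout_indices(0, 0) for c in layout_indices(5, 1))
print(threads)  # [1, 9] -- matches the {1, 9} entry for T[0,5] in L(T) above
```

With the extra k_d = 0 clause, the wrap case degenerates to a plain modulo, while the broadcast case still enumerates every replica.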