diku-dk / futhark

:boom::computer::boom: A data-parallel functional programming language
http://futhark-lang.org
ISC License

Spurious variance in tensor contraction expression going into loop tiling #2125

Closed sortraev closed 8 months ago

sortraev commented 8 months ago

Hello! Apologies for the long post. As part of my master's thesis project, I'm looking into how the block/register tiling passes of the loop tiling stage can be generalized to handle arbitrary tensor contraction patterns.

One condition for the current implementation to fire is that each operand array to the inner redomap is invariant to exactly (k-1) of the k innermost parallel dimensions surrounding the redomap. As an example, for the multiplication of matrices xss and yss:

def matmul [n][m][p] (xss: [n][m]f32) (yss: [m][p]f32) : [n][p]f32 =
  map (\xs ->    -- dim 0
    map (\ys ->  -- dim 1
      map2 (*) xs ys |> f32.sum
    ) (transpose yss)
  ) xss

we have that xs is invariant to dim 1 and ys is invariant to dim 0, i.e. each is invariant to one of the two parallel dimensions above the redomap. This presents an opportunity for reuse (via tiling).
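As a rough illustration of that condition, here is a small Python sketch. It assumes variance is modelled as the set of parallel dimensions (numbered from the outermost, starting at 0) that an operand depends on. The function name fits_current_tiling is hypothetical, and this is not the Futhark compiler's code.

```python
def fits_current_tiling(variances, k):
    """True if every operand is invariant to exactly k-1 of the k
    innermost parallel dimensions (numbered 0..k-1, outermost first).
    `variances` holds one set of dimension numbers per operand."""
    inner = set(range(k))
    return all(len(inner - v) == k - 1 for v in variances)

# Matrix multiplication: xs depends on dim 0 only, ys on dim 1 only.
print(fits_current_tiling([{0}, {1}], 2))        # True
# example_tc5: each operand depends on two of the four dimensions.
print(fits_current_tiling([{0, 3}, {1, 2}], 4))  # False
```

The second call shows why a contraction like example_tc5 falls outside the current condition and needs a generalized one.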

I haven't exactly derived the variance rules for tiling a tensor contraction, nor the restrictions on them, so the answer to the question posed here may turn out not to matter. However, it is still a point of confusion for me, so I hope someone can help.

I have tried to minimize the code example below, but the problem arises only for certain dimensionalities and permutations of dimensions, and this was the smallest example I found that triggers it.

Consider the Futhark program below, which implements the contraction Z[i, b, j, a] = Σ_q X[q, a, i] * Y[b, q, j]:

def example_tc5
  [A][B][I][J]
  [Q]
  (xsss: [Q][A][I]f32)
  (ysss: [B][Q][J]f32)
  : [I][B][J][A]f32 =

  #[unsafe]
  map (\i ->        -- dim 0
    map (\b ->      -- dim 1
      map (\j ->    -- dim 2
        map (\a ->  -- dim 3
          map2 (*) xsss[:, a, i] ysss[b, :, j] |> f32.sum
        ) (iota A)
      ) (iota J)
    ) (iota B)
  ) (iota I)
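For checking compiled results against a known answer, a plain-Python reference for the same contraction may help. It is a sketch assuming nested lists in row-major order, with X of shape [Q][A][I] and Y of shape [B][Q][J]; example_tc5_ref is a hypothetical name.

```python
def example_tc5_ref(X, Y):
    """Z[i][b][j][a] = sum over q of X[q][a][i] * Y[b][q][j].
    X: [Q][A][I], Y: [B][Q][J]; result: [I][B][J][A]."""
    Q, A, I = len(X), len(X[0]), len(X[0][0])
    B, J = len(Y), len(Y[0][0])
    return [[[[sum(X[q][a][i] * Y[b][q][j] for q in range(Q))
               for a in range(A)]   # dim 3
              for j in range(J)]    # dim 2
             for b in range(B)]     # dim 1
            for i in range(I)]      # dim 0

X = [[[1.0]], [[2.0]]]    # Q=2, A=1, I=1
Y = [[[3.0], [4.0]]]      # B=1, Q=2, J=1
print(example_tc5_ref(X, Y))  # [[[[11.0]]]]  (1*3 + 2*4)
```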

Note that the same program can be expressed without iota and explicit indexing, which may or may not change the IR (and hence this question) significantly; however, this requires tedious manual permutation of the dimensions.

From the source program we expect that xsss[:, a, i] would be variant to dims 0 and 3, and ysss[b, :, j] variant to dims 1 and 2.
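This expectation can be read directly off the index expressions: an access is variant to exactly the parallel loop variables it mentions. A toy Python version, assuming every index is a bare loop variable (no arithmetic) and using the hypothetical names PARALLEL_DIMS and variance:

```python
# Parallel loop variables of example_tc5 and their dimension numbers.
PARALLEL_DIMS = {"i": 0, "b": 1, "j": 2, "a": 3}

def variance(index_vars, dims=PARALLEL_DIMS):
    """Parallel dimensions an access depends on, given the loop variables
    used in its index; the reduction variable q is not a parallel dim."""
    return {dims[v] for v in index_vars if v in dims}

x_var = variance(("q", "a", "i"))  # xsss[:, a, i]
y_var = variance(("b", "q", "j"))  # ysss[b, :, j]
print(sorted(x_var), sorted(y_var))  # [0, 3] [1, 2]
```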

The IR going into loop tiling is:

let {map2_arg2_8307 : [Q₄_6857]f32} =
  map2_arg2_r_r_r_8279[gtid_8301, gtid_8302, gtid_8303, 0i64 :+ Q₄_6857 * 1i64]
let {index_8309 : [Q₄_6857]f32} =
  xsss_6860[0i64 :+ Q₄_6857 * 1i64, gtid_8304, gtid_8301]
let {defunc_res_8310 : f32} =
  redomap(Q₄_6857,
          {index_8309, map2_arg2_8307},
          {\ {eta_p_8311 : f32,
              eta_p_8312 : f32}
            : {f32} ->
            let {+_res_8313 : f32} =
              fadd32(eta_p_8311, eta_p_8312)
            in {+_res_8313},
          {0.0f32}},
          \ {eta_p_8314 : f32,
             eta_p_8315 : f32}
            : {f32} ->
            let {defunc_0_f_res_8316 : f32} =
              fmul32(eta_p_8314, eta_p_8315)
            in {defunc_0_f_res_8316})

With segspace:

(gtid_8301 < I₂_6856, -- dim 0
 gtid_8302 < B₁_6858, -- dim 1
 gtid_8303 < J₃_6859, -- dim 2
 gtid_8304 < A₀_6855  -- dim 3
) (~phys_tid_8305)

The input arrays to the redomap are {index_8309, map2_arg2_8307}. Looking at the IR above, we see that index_8309 is a slice of the 3-dimensional xsss_6860, and that this slice depends on gtid_8301 and gtid_8304, which correspond to dims 0 and 3. However, map2_arg2_8307 is a slice of the 4-dimensional map2_arg2_r_r_r_8279, and this slice depends on gtid_8301, gtid_8302, and gtid_8303, which correspond to dims 0, 1, and 2!

This is confirmed by inspecting the variance computed for the kernel statements:

variance: fromList [
  ...
  (VName (Name "map2_arg2") 8307,Names (fromList [(6857,VName (Name "D\8324") 6857),(8279,VName (Name "map2_arg2_r_r_r") 8279),(8301,VName (Name "gtid") 8301),(8302,VName (Name "gtid") 8302),(8303,VName (Name "gtid") 8303)])),
  (VName (Name "index") 8309,Names (fromList [(6857,VName (Name "D\8324") 6857),(6860,VName (Name "xsss") 6860),(8301,VName (Name "gtid") 8301),(8304,VName (Name "gtid") 8304)])),
  ...
]

My immediate impression is that this dependence of the second redomap operand on dim 0 is unwelcome, at least as far as variance analysis is concerned, but I could be wrong.

Anyway, my question is this: where does map2_arg2_r_r_r_8279 come from, and why? In particular, why is it variant to a dimension of the other operand?

EDIT: removed the positive example as it was irrelevant to the question.

EDIT2: As mentioned, the example Futhark program discussed above can be expressed without iotas and explicit indexing, for example like this:

def example_tc6
  [A][B][I][J]
  [Q]
  (xsss': [Q][A][I]f32)
  (ysss': [B][Q][J]f32)
  : [I][B][J][A]f32 =
  let xsss = map flatten xsss' |> transpose |> unflatten |> transpose
  let ysss = map transpose ysss'
  in map (\xss ->
       map (\yss ->
         map (\ys ->
           map (\xs ->
             map2 (*) xs ys |> f32.sum
           ) xss
         ) yss
       ) ysss
     ) xsss

As it turns out, when the program is expressed this way, the IR and segspace are:

let {as_transformed_row_7936 : [Q₄_6825]f32} =
  ysss'_transformed_7504[gtid_7932, gtid_7933, 0i64 :+ Q₄_6825 * 1i64]
let {as_transformed_row_7937 : [Q₄_6825]f32} =
  reshape_7528[gtid_7934, gtid_7931, 0i64 :+ Q₄_6825 * 1i64]
let {defunc_res_7938 : f32} =
  redomap(Q₄_6825,
          {as_transformed_row_7937, as_transformed_row_7936},
          {\ {eta_p_7939 : f32,
              eta_p_7940 : f32}
            : {f32} ->
            let {+_res_7941 : f32} =
              fadd32(eta_p_7939, eta_p_7940)
            in {+_res_7941},
          {0.0f32}},
          \ {eta_p_7942 : f32,
             eta_p_7943 : f32}
            : {f32} ->
            let {defunc_0_f_res_7944 : f32} =
              fmul32(eta_p_7942, eta_p_7943)
            in {defunc_0_f_res_7944})

seg_space:
(gtid_7931 < I₂_6824, gtid_7932 < B₁_6826, gtid_7933 < J₃_6827, gtid_7934 < A₀_6823) (~phys_tid_7935)

with the expected variance on the redomap arrays: as_transformed_row_7937 depends on gtid_7934 and gtid_7931 (dims 3 and 0), and as_transformed_row_7936 on gtid_7932 and gtid_7933 (dims 1 and 2).
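To see why the flatten/transpose/unflatten chain in example_tc6 yields the right permutation, here is a small Python check with hypothetical list-based stand-ins for Futhark's transpose, flatten and unflatten (the stand-in unflatten takes its sizes explicitly). Each element records its original (q, a, i) or (b, q, j) position, so the check confirms that xsss[i][a] is the column xsss'[:, a, i] and ysss[b][j] is ysss'[b, :, j].

```python
# List-based stand-ins (hypothetical) for Futhark's array combinators.
def transpose(xss):
    return [list(col) for col in zip(*xss)]

def flatten(xss):
    return [x for xs in xss for x in xs]

def unflatten(n, m, xs):
    assert len(xs) == n * m
    return [xs[r * m:(r + 1) * m] for r in range(n)]

Q, A, I, B, J = 2, 3, 4, 2, 3
# Tag every element with its original indices.
xsss_ = [[[(q, a, i) for i in range(I)] for a in range(A)] for q in range(Q)]
ysss_ = [[[(b, q, j) for j in range(J)] for q in range(Q)] for b in range(B)]

# let xsss = map flatten xsss' |> transpose |> unflatten |> transpose
#   [Q][A][I] -> [Q][A*I] -> [A*I][Q] -> [A][I][Q] -> [I][A][Q]
xsss = transpose(unflatten(A, I, transpose([flatten(xss) for xss in xsss_])))
# let ysss = map transpose ysss'
#   [B][Q][J] -> [B][J][Q]
ysss = [transpose(yss) for yss in ysss_]

print(xsss[1][2])  # [(0, 2, 1), (1, 2, 1)]
print(ysss[1][2])  # [(1, 0, 2), (1, 1, 2)]
```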

athas commented 8 months ago

Iotas are simplified away entirely after the tiling pass, so they will not show up in the generated code.

sortraev commented 8 months ago

Yes, sorry, what's important in the code is not the use of iota but the use of explicit indexing. For all the example tensor contraction expressions I have tested so far, the problem (extra dimensions, and hence extra variance, added to one of the redomap arrays) does not seem to occur when I rewrite the expression without explicit indexing, as in the example given in my EDIT2.

Or perhaps I misunderstood your reply? Can you please elaborate?

athas commented 8 months ago

If you explicitly index an iota, the iota goes away. Pre-tiling, we still have redomaps, so the iota is not explicitly indexed (at least not fully). Post-tiling, those redomaps get turned into loops with explicit indexing, and then the indexing of iota is simplified away. I'm not quite sure what the question is now.
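The rule described here amounts to rewriting (iota n)[i] to i, assuming i is in bounds. A toy Python rewrite over a hypothetical two-node expression type (Iota, Index) illustrates the idea; it is not Futhark's simplifier.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Iota:    # iota n
    n: object

@dataclass(frozen=True)
class Index:   # arr[i]
    arr: object
    i: object

def simplify(e):
    """Recursively rewrite (iota n)[i] to i; other nodes are rebuilt."""
    if isinstance(e, Index):
        arr, i = simplify(e.arr), simplify(e.i)
        if isinstance(arr, Iota):
            return i
        return Index(arr, i)
    return e

print(simplify(Index(Iota("A"), "gtid_8304")))        # gtid_8304
print(simplify(Index("xs", Index(Iota("Q"), "q"))))   # Index(arr='xs', i='q')
```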

sortraev commented 8 months ago

I think my original question remains, but allow me to simplify:

The two input arrays to the contraction have dimensions xsss : [Q][A][I] and ysss : [B][Q][J], and the result has dimensions [I][B][J][A]. Hence the segspace has dimensions [I][B][J][A], and the reduced dimension is [Q].

Given the map nest in example_tc5 (second snippet in my OP), the first input slice to the redomap should be variant only to the two maps labeled "dim 0" and "dim 3" (the I and A of the segspace), while the second input slice should be variant only to the other two maps labeled "dim 1" and "dim 2" (the B and J of the segspace).

Inspecting the IR right before loop tiling, I see that the first redomap slice, index_8309, comes from xsss as expected. However, the second redomap slice, map2_arg2_8307, comes from some 4D array map2_arg2_r_r_r_8279, even though ysss is 3D, and map2_arg2_r_r_r_8279 is indexed using gtid_8301, which corresponds to the I dimension of the space, i.e. one of the dimensions to which ysss should be invariant (see the snippet below).

In other words, the size [Q] slice of ysss going into the redomap is now variant to one more outer map dimension than it should be.

let {map2_arg2_8307 : [Q₄_6857]f32} =
  map2_arg2_r_r_r_8279[gtid_8301, gtid_8302, gtid_8303, 0i64 :+ Q₄_6857 * 1i64]
let {index_8309 : [Q₄_6857]f32} =
  xsss_6860[0i64 :+ Q₄_6857 * 1i64, gtid_8304, gtid_8301]

athas commented 8 months ago

It ultimately occurs because the program prior to flattening (--extract-kernels) is not a perfect nest of maps. The relevant part looks like this:

  let {defunc_0_map_res_7279 : [I_6458][B_6460][J_6461][A_6457]f32} =
    map(I_6458,
        {iota_res_7223},
        \ {eta_p_7225 : i64}
          : {[B_6460][J_6461][A_6457]f32} ->
          let {defunc_0_map_res_7278 : [B_6460][J_6461][A_6457]f32} =
            map(B_6460,
                {ysss_6463},
                \ {ysss_elem_7254 : [Q_6459][J_6461]f32}
                  : {[J_6461][A_6457]f32} ->
                  let {defunc_0_map_res_7277 : [J_6461][A_6457]f32} =
                    map(J_6461,
                        {iota_res_7227},
                        \ {eta_p_7232 : i64}
                          : {[A_6457]f32} ->
                          let {map2_arg2_7233 : [Q_6459]f32} =
                            ysss_elem_7254[0i64 :+ Q_6459 * 1i64, eta_p_7232]
                          let {defunc_0_map_res_7276 : [A_6457]f32} =
                            map(A_6457,
                                {iota_res_7228},
                                \ {eta_p_7235 : i64}
                                  : {f32} ->
                                  let {index_7275 : [Q_6459]f32} =
                                    xsss_6462[0i64 :+ Q_6459 * 1i64, eta_p_7235, eta_p_7225]
                                  let {defunc_res_7274 : f32} =
                                    redomap(Q_6459,
                                            {index_7275, map2_arg2_7233},
                                            {\ {eta_p_7242 : f32,
                                                eta_p_7243 : f32}
                                              : {f32} ->
                                              let {+_res_7244 : f32} =
                                                fadd32(eta_p_7242, eta_p_7243)
                                              in {+_res_7244},
                                            {0.0f32}},
                                            \ {eta_p_7256 : f32,
                                               eta_p_7257 : f32}
                                              : {f32} ->
                                              let {defunc_0_f_res_7258 : f32} =
                                                fmul32(eta_p_7256, eta_p_7257)
                                              in {defunc_0_f_res_7258})
                                  in {defunc_res_7274})
                          in {defunc_0_map_res_7276})
                  in {defunc_0_map_res_7277})
          in {defunc_0_map_res_7278})

During flattening, that slice (map2_arg2_7233) is distributed. A distributed slice is implemented as a SegMap containing a slice.
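A plain-Python model of what that distributed SegMap computes, assuming nested lists and using the hypothetical name distribute_slice. The result has shape [I][B][J][Q] and its contents do not depend on i, yet any access to it must still supply an I index, which is where the extra variance on dim 0 comes from.

```python
def distribute_slice(ysss, I):
    """Model of map2_arg2_r_r_r: one [Q] slice ysss[b][:, j] per point of
    the I x B x J space. The loop variable i is never used in the body,
    so every i-plane holds the same data."""
    B, Q, J = len(ysss), len(ysss[0]), len(ysss[0][0])
    return [[[[ysss[b][q][j] for q in range(Q)]
              for j in range(J)]
             for b in range(B)]
            for i in range(I)]

ysss = [[[b * 100 + q * 10 + j for j in range(3)] for q in range(2)]
        for b in range(2)]                 # B=2, Q=2, J=3
r = distribute_slice(ysss, I=4)            # shape [4][2][3][2]
print(r[0][1][2], r[3][1][2])              # [102, 112] [102, 112]
```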

athas commented 8 months ago

I removed the "somewhat inefficiently" part; that is in isolation the right way to flatten a slice.

athas commented 8 months ago

If this is a problem, it would certainly be possible to get rid of the SegMap by extending our index-simplification rules a bit.
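One way such a rule could look, sketched in Python over a hypothetical string-based index representation rather than Futhark's IR types: when a SegMap body is nothing but a slice of another array, an access to the SegMap's result can be rewritten as a direct slice of the source, and SegMap dimensions that the slice never used (here I) drop out of the variance. The names inline_slice and free_gtids are illustrative only.

```python
# Hypothetical rewrite rule: if a SegMap defines res[p_0, ..., p_n, :] as
# src[t_0, ..., t_m] (each t_k a SegMap parameter or the full slice ":"),
# then res[e_0, ..., e_n, :] can be replaced by src[t_0, ..., t_m] with each
# p_k substituted by e_k. Parameters the template never uses disappear.
# Assumes the SegMap body is exactly that slice and src is not modified.

def inline_slice(params, src, template, use_indices):
    subst = dict(zip(params, use_indices))
    return src, [subst.get(t, t) for t in template]

def free_gtids(indices):
    """Thread indices an index list mentions, i.e. its variance."""
    return {x for x in indices if x.startswith("gtid_")}

# map2_arg2_r_r_r[i, b, j, :] = ysss[b, :, j], accessed at
# map2_arg2_r_r_r_8279[gtid_8301, gtid_8302, gtid_8303, :].
src, idx = inline_slice(
    params=["i", "b", "j"],
    src="ysss",
    template=["b", ":", "j"],
    use_indices=["gtid_8301", "gtid_8302", "gtid_8303"],
)
print(src, idx)                 # ysss ['gtid_8302', ':', 'gtid_8303']
print(sorted(free_gtids(idx)))  # ['gtid_8302', 'gtid_8303']
```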

sortraev commented 8 months ago

Thanks! I think I got the gist of it.

For now, one condition I require is that each redomap array be variant to at least one and at most (k-1) of the k inner dimensions and, conversely, that each inner dimension has exactly one redomap array variant to it. The latter condition is a simplifying assumption that should eventually be lifted, perhaps replaced with something along the lines of "for n redomap arrays, each inner dimension has at least one and at most (n-1) redomap arrays variant to it", but I'm not quite there yet.
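The two conditions above can be written as a hypothetical Python predicate over a set-of-dimensions model of variance (one set of dimension numbers per redomap array; a sketch, not code from the tiling module). With the variance expected from the source of example_tc5 it holds; with the variance observed before tiling, where map2_arg2_8307 also depends on dim 0, the per-dimension check fails.

```python
def fits_generalized_tiling(variances, k):
    inner = set(range(k))
    # Each redomap array is variant to at least one and at most k-1 inner dims.
    per_array_ok = all(1 <= len(v & inner) <= k - 1 for v in variances)
    # Each inner dim has exactly one redomap array variant to it.
    per_dim_ok = all(sum(d in v for v in variances) == 1 for d in inner)
    return per_array_ok and per_dim_ok

expected = [{0, 3}, {1, 2}]     # xsss[:, a, i], ysss[b, :, j]
observed = [{0, 3}, {0, 1, 2}]  # index_8309, map2_arg2_8307
print(fits_generalized_tiling(expected, 4))  # True
print(fits_generalized_tiling(observed, 4))  # False
```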

I can't say for certain that this condition couldn't be expressed differently, or that it won't be made redundant by other conditions I have yet to formulate, but for now I'd say it does get a bit in the way of determining variance, yes.

Regardless, I have plenty of example source programs that do not trigger this behavior, so I should be able to reach a working prototype, and there's no rush to do anything about it. On the other hand, if you have suggestions for a different strategy to implement in my module that would not require changes elsewhere, I'm all ears!

athas commented 8 months ago

Programming with explicit indexes is wrong. It should be done with replicates. But I think I will implement the simplifications anyway - it is not so difficult.

sortraev commented 8 months ago

> Programming with explicit indexes is wrong.

I agree! I only tested that version of the program because using combinations of flatten/transpose/unflatten to obtain the correct permutation of indices can be a little tedious (until you get used to it and start working them out step by step), so I imagined some users (especially newbs) might opt for doing it that way, even if I agree that bad coding style warrants punishment in most cases.

> It should be done with replicates.

Oo, this sounds interesting. Can you demonstrate?

> But I think I will implement the simplifications anyway - it is not so difficult.

Awesome! :)