tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

Support tiled data movement in TTMetal #482

Closed · rpavlovicTT closed 2 weeks ago

rpavlovicTT commented 2 weeks ago

Fix #322

This commit adds support for data movement between tiled layouts. A few key pieces:

  1. TiledLinearMap - the original linear-space map translated into tile space. Example:

    // tensor: tensor<2x3x64x128xf32,
    //        #tt.layout<(d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3),
    //        undef, <1x1>, memref<384x128xf32, #tt.memory_space<l1>>>> 
    // scalar linear: (d0, d1, d2, d3) -> (d0 * 192 + d1 * 64 + d2, d3), 
    // tile linear: (d0, d1, d2, d3) -> (d0 * 6 + d1 * 2 + d2 floordiv 32, d3 floordiv 32).
  2. TiledShape - the same as the original shape except that the dims are computed with TiledLinearMap. Given the above example:

    // originalScalarShape: [384, 128]
    // tiledShape: [12, 4]
  3. postLinearIndexToVisitingIndex - a C++ map that keeps the mapping between tiled indices and the initial linear index. When tile space is traversed we can then easily map back to linear space and from there to physical space.
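
For illustration, a hedged sketch of how such a tile-space map could be derived; the helper name and the 32x32 tile size are assumptions for this sketch, not necessarily the implementation's exact shape:

    #include <cassert>
    #include <cstdint>

    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "llvm/ADT/SmallVector.h"

    // Illustrative helper (not necessarily the PR's TiledLinearMap): translate a
    // scalar linear map into tile space by floor-dividing its last two results
    // by the tile dimensions (32x32 tiles assumed).
    mlir::AffineMap makeTileLinearMap(mlir::AffineMap scalarLinear,
                                      int64_t tileH = 32, int64_t tileW = 32) {
      unsigned rank = scalarLinear.getNumResults();
      assert(rank >= 2 && "expected at least a 2D linear map");
      llvm::SmallVector<mlir::AffineExpr> results(
          scalarLinear.getResults().begin(), scalarLinear.getResults().end());
      results[rank - 2] = results[rank - 2].floorDiv(tileH);
      results[rank - 1] = results[rank - 1].floorDiv(tileW);
      return mlir::AffineMap::get(scalarLinear.getNumDims(),
                                  scalarLinear.getNumSymbols(), results,
                                  scalarLinear.getContext());
    }

    // For the scalar linear above this produces
    //   (d0, d1, d2, d3) -> ((d0*192 + d1*64 + d2) floordiv 32, d3 floordiv 32),
    // which agrees with the tile linear in the example for in-range indices, and
    // the corresponding tiled shape works out to [384, 128] -> [12, 4].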

Added a few unit tests that can be run on silicon with:

ttmlir-opt --ttir-load-system-desc="path=n300.ttsys" --ttir-implicit-device
  --ttir-allocate --convert-ttir-to-ttmetal --ttmetal-serialize-to-binary="output=out.ttm"
  test/ttmlir/Silicon/TTMetal/tiled_reblock.mlir
ttrt run --program-index <idx> --identity --atol 1e-02 out.ttm

They should pass since no data modification is done.

rpavlovicTT commented 2 weeks ago

Hey @rpavlovicTT, this looks great! There are a few minor comments inline that I think we should address, and then I'm cool with landing it as is.

I do think, though, that we should have a follow-on change that gets rid of the postLinearIndexToVisitingIndex map; it will get quite large as the tensor size increases, and it requires an additional lookup.
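
To illustrate what I mean, the kind of state involved looks roughly like this (the types and keying are assumed for illustration, not the PR's actual declaration):

    #include <cstdint>

    #include "llvm/ADT/DenseMap.h"

    // Roughly one entry per tile, so memory grows with the tile count, and
    // recovering a tile's visiting order costs an extra hash lookup inside the
    // data-movement loop.
    llvm::DenseMap<std::int64_t, std::int64_t> postLinearIndexToVisitingIndex;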

I might suggest that we eventually make the following changes:

1. We could make the `getTileShape` function read something like the following, i.e. the affine map for the tiled shape just gets folded inline, keeping it separate from `getTileLinearMap` (see the next point; a worked check against the example tensor follows after this list):
// Shift the shape to last-index space for composing with the index-space linear map.
llvm::SmallVector<int64_t> lastIndex(tensorShape.begin(), tensorShape.end());
for (int64_t &d : lastIndex) d -= 1;
mlir::AffineMap linear = getLinear();
unsigned rank = linear.getNumResults();
assert(rank >= 2);
mlir::AffineExpr y = linear.getResult(rank - 2);
mlir::AffineExpr x = linear.getResult(rank - 1);
// Fold the tile division inline; floorDiv on indices == ceilDiv on the shape.
llvm::DenseMap<mlir::AffineExpr, mlir::AffineExpr> repl = {{y, y.floorDiv(tileH)}, {x, x.floorDiv(tileW)}};
mlir::AffineMap tiled = linear.replace(repl, linear.getNumDims(), linear.getNumSymbols());
// Evaluate at the last index and shift back to shape space.
auto tileShape = tiled.compose(lastIndex);
for (int64_t &d : tileShape) d += 1;
return tileShape;
2. We repurpose `getTileLinearMap` to perform a different transformation, i.e. return an affine map that can accept a tile shape from above; this will let us make zero modifications to the `calculateDataMovement` interface. I keep going back and forth in my head and feel like I'm overlooking something major, but I think this becomes an identity function (and presumably we have to assert that `getLinear().isPermutation() == false`):
// Under this proposal, tile-space traversal just uses an identity map of matching rank.
mlir::AffineMap linear = getLinear();
return mlir::AffineMap::getMultiDimIdentityMap(linear.getNumResults(), getContext());
3. Calling `calculateDataMovement` becomes something like:
if (isTiled) {
  // Project input/output layouts into tile space before computing data movement.
  auto srcTiled = inputLayout.projectOnto(inputLayout.getTileLinearMap(), ....);
  auto dstTiled = outputLayout.projectOnto(outputLayout.getTileLinearMap(), ...);
  calculateDataMovement(inputLayout.getTileShape(), inputLayout.getElementSizeBytes(),
                        srcTiled, dstTiled);
} else {
  calculateDataMovement(inputTy.getShape(), inputLayout.getElementSizeBytes(), src, dst);
}
4. Glossing over the details of how `projectOnto` needs to change, I think the main thing that needs to change is `auto logicalShardShape = calculateLogicalShardShape`; this needs to be in terms of tiles rather than scalars.
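
As a sanity check of point 1 against the example tensor in the PR description (assuming 32x32 tiles and the `tileH`/`tileW` names used above):

    // tensorShape       = [2, 3, 64, 128]  ->  lastIndex = [1, 2, 63, 127]
    // linear(lastIndex) = (1*192 + 2*64 + 63, 127) = (383, 127)
    // tiled(lastIndex)  = (383 floordiv 32, 127 floordiv 32) = (11, 3)
    // +1 per dim        -> tileShape = [12, 4], matching tiledShape above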

Nick, thanks for the detailed explanation. I am aware that postLinearIndexToVisitingIndex is not the optimal approach, but I guess we can commit it and optimize later. As for your 2nd point and the identity claim you made: you are somehow ignoring whatever expression the linear mapping was describing (just keeping the rank)? That's the part I am having trouble understanding, obviously, but I'll give it a try and see on an example.

nsmithtt commented 2 weeks ago

> Nick, thanks for the detailed explanation. I am aware that postLinearIndexToVisitingIndex is not the optimal approach, but I guess we can commit it and optimize later. As for your 2nd point and the identity claim you made: you are somehow ignoring whatever expression the linear mapping was describing (just keeping the rank)? That's the part I am having trouble understanding, obviously, but I'll give it a try and see on an example.

Yeah agreed, although it does slightly complicate the calculateDataMovement pass because now there is extra state that needs to be managed with this hash map.

Yes, we are ignoring the expression in the linear map, and I think this is what I previously haven't been communicating super well. We first prove that the linear expression doesn't change between input and output; if this holds, then it seems safe to reinterpret the input and output as just some bag of tiles that needs to be transferred. Within the tiles the layout might be collapsed and have some complicated affine expression, but since we proved that input and output have the same complicated affine expression, the movement can be reinterpreted as an identity transformation where the only legal attributes that can change are grid shape and memory space.
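
A minimal sketch of that precondition check, assuming both layouts expose their linear map via a `getLinear()`-style accessor (the function name here is illustrative):

    #include "mlir/IR/AffineMap.h"

    // If input and output layouts carry the exact same linear expression, the
    // movement can be treated as a tile-for-tile identity; only grid shape and
    // memory space are allowed to differ.
    bool canReinterpretAsTileIdentity(mlir::AffineMap inputLinear,
                                      mlir::AffineMap outputLinear) {
      // AffineMaps are uniqued in the MLIRContext, so equality compares the
      // underlying expressions exactly.
      return inputLinear == outputLinear;
    }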

rpavlovicTT commented 2 weeks ago

Hey @nsmithtt, I updated the change according to your idea; please take another look :)