plaidml / tpp-mlir

TPP experimentation on MLIR for linear algebra
https://arxiv.org/abs/2404.15204

Loop tiling, shuffle and expansion passes #922

Closed KavithaTipturMadhu closed 3 months ago

alheinecke commented 5 months ago

I haven't done a detailed review, but it seems the old 2D parallelization is disabled/removed from tpp-run. I suggest that, for now, we keep the old one functional in the default pass and make the new one available through a different knob. Then we can compare both and retire the old one when the new one completely covers the old use cases.

KavithaTipturMadhu commented 5 months ago

I am working on unit tests for the passes as well as MLP integration tests; in the meantime we can start the review.

KavithaTipturMadhu commented 5 months ago

Performance numbers are comparable.

adam-smnk commented 5 months ago

I'm playing with the passes to get a better idea of the transforms, but I'm running into an error that I don't really understand. Could you help me out with this case?

Input kernel:

```mlir
#map = affine_map<(d0) -> (d0 * 32)>
func.func @entry(%arg0: memref<32x32x32x32xbf16>, %arg1: memref<32x32x16x32x2xbf16>, %arg2: memref<1024x1024xbf16>) {
  %c32_i64 = arith.constant 32 : i64
  scf.forall (%arg3, %arg4) in (32, 32) {
    %0 = affine.apply #map(%arg3)
    %1 = affine.apply #map(%arg4)
    %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
    %subview_0 = memref.subview %arg1[%arg4, 0, 0, 0, 0] [1, 32, 16, 32, 2] [1, 1, 1, 1, 1] : memref<32x32x16x32x2xbf16> to memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
    %subview_1 = memref.subview %arg2[%0, %1] [32, 32] [1, 1] : memref<1024x1024xbf16> to memref<32x32xbf16, strided<[1024, 1], offset: ?>>
    %2 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 1024, 1024, 1024] flags = (vnni_b) data_type = bf16
    xsmm.brgemm(data_type = bf16, %2, %subview, %subview_0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>, memref<32x32xbf16, strided<[1024, 1], offset: ?>>, i64) -> ()
  }
  return
}
```

Pipeline: `tpp-opt ../test.mlir -loop-insertion-pass="M-tile-shape=2,4 N-tile-shape=4,8"`

I end up with an error: `require m tile shape to match tensor shape`

KavithaTipturMadhu commented 5 months ago

> I'm playing with the passes to get a better idea of the transforms, but I'm running into an error that I don't really understand. Could you help me out with this case? [...] I end up with an error: `require m tile shape to match tensor shape`

The tile shapes are expected to match the leading dimensions of the tensors. In this case, arg0's leading dimension is 32, so the M-tile-shape values must be factors of 32. The same holds for N-tile-shape: the tile factors are expected to be factors of 32, which is the leading dimension of arg1.
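Roughly, the constraint could be expressed as something like the following sketch (illustrative only, and an assumption: it checks that the supplied tile factors multiply back to the leading dimension; the actual condition in the pass may differ). Under that assumption, with a leading dimension of 32, shapes such as 4,8 or 2,16 would pass the check.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical check mirroring the constraint described above: the product
// of the proposed tile factors must equal the tensor's leading dimension.
static bool tileShapeMatchesLeadingDim(const std::vector<int64_t> &tileShape,
                                       int64_t leadingDim) {
  int64_t product = 1;
  for (int64_t factor : tileShape)
    product *= factor;
  return product == leadingDim;
}
```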

adam-smnk commented 5 months ago

> The tile shapes are expected to match the leading dimensions of the tensors.

Thanks, that helped me go a bit further.

There is some issue with the pipeline currently. When I try to apply loop-insertion-pass on its own with tpp-opt, it fails to tile. However, tpp-run works with any arbitrary M-tile-shape and N-tile-shape arguments. Since the pass does not signal any failure, I think the pass manager simply continues. I can see some changes in the IR at the end with print-mlir=mid, so other passes still kick in, but I think loop-insertion itself currently does nothing.

KavithaTipturMadhu commented 5 months ago

You might be seeing the shuffle and expansion passes, and subsequent passes, doing something in tpp-run even though tiling did nothing. The current implementation is such that the other passes continue to work even when loop insertion does not apply; the same happens with tpp-opt when multiple passes are run. Should this behavior change? @rengolin @alheinecke

adam-smnk commented 5 months ago

> The current implementation is such that the other passes continue to work even when loop insertion does not apply; the same happens with tpp-opt when multiple passes are run.

It's perfectly fine for the pass to be opportunistic.

Looking at it a bit closer, my point is that more transformations kick in for `mlir-gen --kernel=args --batch=256 --layers=1024,1024 --tiles=32,32,32` compared to `mlir-gen --kernel=args --batch=256 --layers=1024,1024`. I think there may be an issue handling cases where not all BRGEMM arguments are prepacked. I would expect it to work for either, but maybe it's intentional?

I think I'll hold off on further review until the lit tests land, as I'm not sure what to expect currently.

KavithaTipturMadhu commented 5 months ago

> I think there may be an issue handling cases where not all BRGEMM arguments are prepacked. I would expect it to work for either, but maybe it's intentional?

BRGEMM arguments are expected to be tiled, yes. I have added a check to ensure that arg2 has the same rank as arg0; I had missed that, thanks for reminding me @adam-smnk.
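Roughly, the rank guard looks something like this sketch (illustrative names and placement, not the exact code in the pass):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical helper: succeed only when both values are memrefs of the
// same rank, mirroring the arg0/arg2 rank check mentioned above.
static bool haveSameMemRefRank(Value lhs, Value out) {
  auto lhsType = dyn_cast<MemRefType>(lhs.getType());
  auto outType = dyn_cast<MemRefType>(out.getType());
  return lhsType && outType && lhsType.getRank() == outType.getRank();
}
```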

KavithaTipturMadhu commented 5 months ago

Unit tests have been updated. I have propagated the failures in the loop expansion and loop shuffle passes as warnings so that the pipeline does not stall, since there are loops in MLP that do not have the same dimensions as the BRGEMM loops. Please review this @rengolin @adam-smnk @alheinecke.
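Concretely, the pattern is to warn and return instead of signalling pass failure, roughly like the sketch below (the pass name and the matching predicate are placeholders, not the actual code in this PR):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
// Illustrative sketch of an opportunistic pass: if the loop nest does not
// look like the expected BRGEMM structure, emit a warning and return
// without calling signalPassFailure(), so later passes still run.
struct LoopShuffleLikePass
    : PassWrapper<LoopShuffleLikePass, OperationPass<func::FuncOp>> {
  void runOnOperation() override {
    func::FuncOp func = getOperation();
    if (!looksLikeBrgemmLoopNest(func)) {
      func.emitWarning("loops do not match BRGEMM dimensions; skipping");
      return; // deliberately not signalPassFailure()
    }
    // ... perform the actual loop shuffle / expansion here ...
  }

  // Placeholder predicate; the real pass has its own matching logic.
  static bool looksLikeBrgemmLoopNest(func::FuncOp) { return false; }
};
} // namespace
```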

rengolin commented 5 months ago

Benchmark regression for OMP 16 threads (~5%). It may be nothing, as the DNN benchmarks also change for 16 threads (~6%). I'll trigger the run again to check whether this is just a fluke.

rengolin commented 5 months ago

> Benchmark regression for OMP 16 threads (~5%). It may be nothing, as the DNN benchmarks also change for 16 threads (~6%). I'll trigger the run again to check whether this is just a fluke.

The new run is 7% slower on MLP 16 x bf16 only, so there may be something to it. It could be bf16 fusion. Did you compare the IR emitted between main and your PR?

adam-smnk commented 5 months ago

> The new run is 7% slower on MLP 16 x bf16 only

OMP 16 usually varies the most, but this is larger than usual. I'd suggest double-checking by manually running the benchmarks on the main branch and on this PR directly on the cluster.

KavithaTipturMadhu commented 5 months ago

> The new run is 7% slower on MLP 16 x bf16 only, so there may be something to it. It could be bf16 fusion. Did you compare the IR emitted between main and your PR?

Fusion seems to work just fine; I compared the IRs and they are similar. I measured the time taken by my PR and by the parallel-task-grid configuration, and the times are comparable on the SPR machine. In fact, the parallel-task-grid configuration seems to take more time: parallel-task-grid time is 0.000100212 whereas tiling time is 9.32289e-05.

Edit: @rengolin You're right, fusion does not seem to work; I see the BRGEMM and the unary and binary ops listed separately in tpp-opt. I'll look into this. I was looking at the performance numbers of my branch for the two configurations and didn't inspect the IR correctly the first time: the IR has BRGEMM plus unary and binary ops instead of a fused BRGEMM, so it looks like combinexsmm doesn't work. I cross-checked this with tpp-opt.

Edit 2: I have modified the pipeline and fusion seems to be working now; I have started benchmarks on the PR. I will update the patch with MLP integration tests in the meantime.

Edit 3: I have updated the patch to take just one tiling factor, as per @alheinecke's suggestion, with the second tiling factor inferred from the tensor's leading dimension; the code is still capable of more than 2D tiling. I've added an MLP integration test to ensure that fusion works. I'll rerun the benchmarks.
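Roughly, the inference of the second factor amounts to something like the following sketch (illustrative helper, not the exact code in the patch):

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper: derive the second tile factor from the tensor's
// leading dimension and a single user-provided factor.
static std::optional<int64_t> inferSecondTileFactor(int64_t leadingDim,
                                                    int64_t userFactor) {
  if (userFactor <= 0 || leadingDim % userFactor != 0)
    return std::nullopt; // the factor must evenly divide the leading dimension
  return leadingDim / userFactor;
}
```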

adam-smnk commented 4 months ago

Some benchmark errors and regressions. Reminder to revisit these later.

KavithaTipturMadhu commented 4 months ago

> Review on loop insertion: this is more like parallel tiling, so why is it called "insertion"?

Earlier, the plan was to "insert" an scf.parallel loop around the brgemm op; that is why I started with the name loop insertion rather than loop parallelization.

KavithaTipturMadhu commented 3 months ago

Closing this PR since we're moving this pass to the vector dialect.