nod-ai / iree-amd-aie

IREE plugin repository for the AMD AIE accelerator
Apache License 2.0

Matmul + Vectorization + Objectfifo #431

Open Abhishek-Varma opened 2 weeks ago

Abhishek-Varma commented 2 weeks ago

We already have an e2e Matmul working with the in-flight Objectfifo backend. The current in-flight branch being maintained by @jtuyls is https://github.com/nod-ai/iree-amd-aie/tree/jornt_cpp_pipeline.

We're now trying to support the same with vectorization switched on. I'm maintaining that on top of @jtuyls's branch: https://github.com/nod-ai/iree-amd-aie/tree/avarma_matmul_elem

Issues :-

The current IR log state : e2e Matmul + Vectorization + Objectfifo

yzhang93 commented 2 weeks ago

@Abhishek-Varma I've rebased your fix commits on top of Jorn's latest branch which is ahead of your branch by several important commits. I also fixed another problem in distribute-core-and-objectfifo-pass, and now the lowered IR looks correct to me after this pass. Also with the change, it no longer fails after AccessToAcquireRelease, but fails at a later pass create-logical-objectfifo-link.

Please refer to this new branch: https://github.com/nod-ai/iree-amd-aie/tree/objectfifo_vectorization And the latest dump IR: matmul_objectfifo_vectorize.txt

Abhishek-Varma commented 2 weeks ago

Hi @yzhang93 - I didn't get any failure in create-logical-objectfifo-link.

The failure with the https://github.com/nod-ai/iree-amd-aie/tree/objectfifo_vectorization branch occurs at lower-to-aie pass instead (the final pass in the pipeline).

And that failure turned out to be a red herring; the main issue is in the distribute-core-and-objectfifo pass, and is caused by the following structure that the fold-memref-alias pass produces :-

%a = memref.subview %alloc
// PROLOGUE
aie.core {
     linalg.fill (%cst, %a)
     scf.for // vectorized matmul loop nest 1
         scf.for // vectorized matmul loop nest 2
             scf.for // vectorized matmul loop nest 3
                  %b = memref.subview %alloc
                  linalg.generic ... outs(%b)
}
// MAIN
aie.core {
  scf.for
   scf.for // vectorized matmul loop nest 1
         scf.for // vectorized matmul loop nest 2
             scf.for // vectorized matmul loop nest 3
                  %b = memref.subview %alloc
                  linalg.generic ... outs(%b)
}
// EPILOGUE
aie.core {
   scf.for // vectorized matmul loop nest 1
         scf.for // vectorized matmul loop nest 2
             scf.for // vectorized matmul loop nest 3
                  %b = memref.subview %alloc
                  linalg.generic ... outs(%b)
}

%a is memref<2x2x8x8x4x4xi32, 2 : i32> to memref<1x1x8x8x4x4xi32, 2> (only considering shape).

%b is memref<2x2x8x8x4x4xi32, 2 : i32> to memref<1x1x1x1x4x4xi32, 2> (again, only considering shape).

We get the following error message from lower-to-aie pass because the AIE objectfifo link op fails the verification :-

<unknown>:0: error: Total size of input objFifos in ObjectFifoLinkOp must be equal to size of output objFifo
<unknown>:0: note: see current operation: "aie.objectfifo.link"() <{fifoIns = [@obj6, @obj7, @obj8, @obj9], fifoOuts = [@obj10]}> : () -> ()

And the reason for that is that the current logic in distribute-core-and-objectfifo incorrectly replaces %a (which is memref<1x1x8x8x4x4xi32, 2>) with a new alloc op (which is memref<1x1x1x1x4x4xi32, 2>).
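
To make the size mismatch concrete, here is a minimal Python sketch that just counts elements for the two shapes quoted above. The helper name num_elements is hypothetical and the snippet is not part of the pass; it only illustrates how much smaller the replacement allocation is than the buffer that %a actually covers.

```python
from math import prod


def num_elements(shape):
    """Number of elements in a statically shaped memref (hypothetical helper)."""
    return prod(shape)


# Shapes quoted in the comment above (element type and memory space omitted).
shape_a = (1, 1, 8, 8, 4, 4)  # %a, the destination of linalg.fill
shape_b = (1, 1, 1, 1, 4, 4)  # %b, the per-iteration matmul output tile

print(num_elements(shape_a))  # 1024
print(num_elements(shape_b))  # 16
# A buffer shaped like %b holds 64x fewer elements than %a needs.
print(num_elements(shape_a) // num_elements(shape_b))  # 64
```

Whether this 64x difference is exactly what trips the aie.objectfifo.link verifier depends on the remaining lowering, so treat the numbers as an illustration of the shape problem rather than a reproduction of the error.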

Therefore we need to change the distributeLocalMemory function to handle this case (CC: @MaheshRavishankar @jtuyls ).

Here is the gist of the e2e IR's current state : e2e IR vectorization.

I'm hoping that with the above suggested fix we might be able to get through to AIE Dialect.

yzhang93 commented 2 weeks ago

I tried two ways to solve the problem, but haven't got either to work.

  1. I tried to directly change distributeLocalMemory function by replacing the subview ops with the new allocation.

I was able to replace

%subview = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : memref<2x2x8x8x4x4xi32, 2 : i32> to memref<1x1x8x8x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32>

but had a problem with this one:

%subview_13 = memref.subview %alloc_3[%arg2, %arg3, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<2x2x8x8x4x4xi32, 2 : i32> to memref<1x1x1x1x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32>

I think that instead of the single subview above, we probably still need two subview ops, i.e. the subview ops as they were before the FoldMemRefAliasOps pass.

%subview_13 = memref.subview %alloc_3[%arg2, %arg3, 0, 0, 0, 0] [1, 1, 8, 8, 4, 4] [1, 1, 1, 1, 1, 1] : memref<2x2x8x8x4x4xi32, 2 : i32> to memref<1x1x8x8x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32>               

%subview_16 = memref.subview %subview_13[0, 0, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x8x8x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32> to memref<1x1x1x1x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32>
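
The relationship between the two-subview form and the folded single subview can be sketched in Python as follows. This is only an illustration under the assumption that both subviews have unit strides and the same rank, as in the IR quoted above; fold_subview_offsets is a hypothetical helper, not the upstream FoldMemRefAliasOps code. Offsets are either integers or strings naming SSA values.

```python
def _add(outer, inner):
    """Add two offsets that are either ints or SSA value names (strings)."""
    if outer == 0:
        return inner
    if inner == 0:
        return outer
    if isinstance(outer, int) and isinstance(inner, int):
        return outer + inner
    return f"({outer} + {inner})"


def fold_subview_offsets(outer_offsets, inner_offsets):
    """Offsets of a subview-of-subview, expressed on the original memref.

    Only valid when the outer subview has unit strides and no rank reduction.
    """
    if len(outer_offsets) != len(inner_offsets):
        raise ValueError("both subviews must have the same rank")
    return [_add(o, i) for o, i in zip(outer_offsets, inner_offsets)]


outer = ["%arg2", "%arg3", 0, 0, 0, 0]   # offsets of %subview_13
inner = [0, 0, "%arg5", "%arg4", 0, 0]   # offsets of %subview_16
print(fold_subview_offsets(outer, inner))
# ['%arg2', '%arg3', '%arg5', '%arg4', 0, 0]
```

The result matches the offsets of the single folded subview shown earlier, which is why the information about the intermediate 1x1x8x8x4x4 view is no longer visible after folding.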

So I moved on to the second approach.

  2. I hacked the upstream code in FoldMemRefAliasOps to avoid folding these two subview ops, and got this IR dump: https://gist.github.com/yzhang93/48b5a4fb60dda465158f9dabbab14d91. It now hits a different error when running the DistributeCoresAndObjectFifos pass.

I'm not fully sure which way is more reasonable and can eventually fix the problem.

Abhishek-Varma commented 2 weeks ago

Hi @yzhang93 - so I worked on adding a fix from scratch within distribute-core-and-objectfifo pass itself and was able to make it work.

There were two other issues afterwards in the same pass :-

  i. The objectfifo needs to be attached to tiles, using findUsersInCoreAndAddTiles as a helper method.
  ii. Logical objectfifo accesses need to be inserted for the linalg op's operands, using the insertLogicalObjectFifoAccess method.

I triaged and found that I had already solved those two specific issues in the avarma_matmul_elem branch, but saw that a few of the code snippets pertaining to the fixes were deleted/modified as part of your objectfifo_vectorization branch - I'm not sure why though. Could you help explain? The fixes I added look good to me and get us the expected IR as discussed, so I'm trying to understand the rationale.

I've pushed all changes on branch avarma_matmul_vectorization_objectfifo instead.

Here is the current state of the e2e IR log now after the above fixes : e2e IR matmul + vectorization + objectfifo

CC: @MaheshRavishankar @jtuyls

yzhang93 commented 2 weeks ago

I triaged and found that I have already solved those two specific issues in avarma_matmul_elem branch but saw that a few part of the code snippets pertaining to the fixes were deleted/modified as part of your objectfifo_vectorization branch - I'm not sure why though.

As discussed offline, your previous branch didn't generate the correct IR. There were three problems: 1) the access op doesn't have the correct memory pattern; 2) the inputs aren't broadcast correctly; 3) it generates an incorrect allocation, as pointed out in the above comment https://github.com/nod-ai/iree-amd-aie/issues/431#issuecomment-2176505316.

My branch objectfifo_vectorization was able to solve the first two issues but not the third one.

Now with your new branch, the third problem has been fixed, but the first two remain. I'll take another look today to see if we can solve all three problems at the same time.

Abhishek-Varma commented 2 weeks ago

So @yzhang93 and I discussed this offline.

Context: The distribute-core-and-objectfifo pass is structured in the following way :-

    ....
    A: distributeLocalMemory (this involves the AllocOp fix)
    ...
    B: insertLogicalObjectFifoAccess (this involves the issue pointed above in `point ii`)
    ...
    C: assignLocalTiles (the one invoking the helper method findUsersInCoreAndAddTiles as mentioned above in `point i`)
    ...

The current structure has the alloc op replacements correct - but we also need the following (when comparing with Matmul without vectorization) :-

  1. We need amdaie.logicalobjectfifo.access(%SSA, Read/Write) in the IR; currently the pass sets all of them to amdaie.logicalobjectfifo.access(%SSA, None).

  2. We also need broadcasted allocs; currently (I just saw this while adding this update) we get :-

      %13 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile_17} : memref<1x1x8x4x8x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>
      %14 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile_18} : memref<1x1x8x4x8x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>
      %15 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile_19} : memref<1x1x8x4x8x4xi32, 2 : i32> -> !amdaie.logicalobjectfifo<memref<1x1x8x4x8x4xi32, 2 : i32>>
      %16 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile_20} 

    In case of Matmul it is %13 = amdaie.logicalobjectfifo.from_memref %alloc, {%tile_17, %tile_18} (@yzhang93 correct me if I'm wrong).

There were two other issues afterwards in the same pass :- i. Objectfifo needs to be attached tiles using findUsersInCoreAndAddTiles as a helper method. ii. The other was for inserting logical objectfifo access for linalg op's operands using insertLogicalObjectFifoAccess method.

I triaged and found that I have already solved those two specific issues in avarma_matmul_elem branch but saw that a few part of the code snippets pertaining to the fixes were deleted/modified as part of your objectfifo_vectorization branch - I'm not sure why though. Could you help explain? Because the fixes I added looks good to me and is getting us the expected IR as discussed - so I'm trying to understand the rationale.

So the fixes which @yzhang93 added were part of B and C, to match them with those of the Matmul. But after fixing A it crashed.

Therefore we now need to address B and C in the context of the AllocOps getting replaced.

CC: @MaheshRavishankar @jtuyls

yzhang93 commented 2 weeks ago

@Abhishek-Varma Based on your new branch, I've fixed the other two issues in the DistributeCoresAndObjectFifos pass. So the three issues we observed earlier should now all be fixed.

Please check the current IR dump here https://gist.github.com/yzhang93/d5cbb97cf2790c6328fd3daef2a34814.

After the fix, now the lowering with vectorization fails at AMDAIELowerToAIE with error

<stdin>:20:16: error: different memory spaces specified for base memref type 'memref<1x1x4x8x4x8xi32, 1>' and subview memref type 'memref<1x1x1x1x4x8xi32, strided<[1024, 1024, 256, 32, 8, 1], offset: ?>, 2 : i32>'

This is simply because AMDAIELowerToAIE doesn't have a function to rewrite the memref.subview ops which are the operands of linalg.generic ops. And now the memory space of the subview ops should be rewritten to 1 according to https://github.com/nod-ai/iree-amd-aie/blob/avarma_matmul_vectorization_objectfifo/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp#L147.

Could you please help solve this problem today? Thanks.

Abhishek-Varma commented 2 weeks ago

I've added the fix for handling memref::SubViewOp in lower-to-aie and the changes are in : avarma_matmul_vectorization_objectfifo branch.

We are now able to get to the final AIE dialect IR.

Here is the current state of the IR : e2e IR Matmul + Vectorization + Objectfifo. I tried skimming through the Matmul without vectorization IR - overall structure looks okay to me.

CC: @MaheshRavishankar @jtuyls @yzhang93

yzhang93 commented 2 weeks ago

@Abhishek-Varma With your fix in LowerToAIE, it generates bad code as below

scf.for %arg1 = %c0 to %c8 step %c1 {
          scf.for %arg2 = %c0 to %c8 step %c1 {
            scf.for %arg3 = %c0 to %c4 step %c1 {
              linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%reinterpret_cast_11, %reinterpret_cast_10 : memref<1x1x4x8x4x8xi32, 1>, memref<1x1x8x4x8x4xi32, 1>) outs(%reinterpret_cast : memref<1x1x8x8x4x4xi32, 1>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
              ^bb0(%in: i32, %in_14: i32, %out: i32):
                %13 = arith.muli %in, %in_14 : i32
                %14 = arith.addi %out, %13 : i32
                linalg.yield %14 : i32
              }
            }
          }
        }

The subview ops were gone and the linalg.generic operands were not using the vectorization instruction sizes. I pushed another commit in the same branch which has fixed the above issue.

It can now compile and generate a vmfb, but has a numerical issue. For example, with the inputs

iree-run-module --device=xrt --module=pack_peel.vmfb \
  --input=128x256xi32=1 --input=256x128xi32=2 --function=matmul_i32

All the values in the output matrix should be 256, but the output I got has zeros every 32 elements.

128x128xi32=[256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 
256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...][...]
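
A pattern like this is easier to see as run lengths than as a raw dump. The following Python sketch is a debugging aid only, with run_lengths as a hypothetical helper; the sample row is synthesized to follow the description above (32 values of 256 followed by 32 zeros, repeated across a 128-wide row) rather than parsed from the actual output.

```python
from itertools import groupby


def run_lengths(row):
    """Collapse a row into (value, count) pairs for consecutive equal values."""
    return [(value, len(list(group))) for value, group in groupby(row)]


# Synthesized row following the description: 32 x 256, then 32 x 0, repeated.
row = [256] * 32 + [0] * 32 + [256] * 32 + [0] * 32
print(run_lengths(row))  # [(256, 32), (0, 32), (256, 32), (0, 32)]
```

Zero runs that line up with a tile or block width usually hint at an addressing problem rather than an arithmetic one, though that is a general observation and not something this dump proves on its own.
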
yzhang93 commented 1 week ago

I've found the bug, and pushed a fix in avarma_matmul_vectorization_objectfifo branch.

The issue was that when the subview op is replaced from memref<2x2x8x8x4x4xi32, 2 : i32> to memref<1x1x8x8x4x4xi32, 2 : i32>, the offsets of the first two dimensions were not updated properly.
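
A small Python sketch of why stale leading offsets are a problem once the buffer shrinks. It assumes row-major contiguous strides and assumes that the corrected offsets for the first two dimensions are 0; the comment above does not spell out the exact fix, so both the helper names and the corrected values are illustrative.

```python
from math import prod


def contiguous_strides(shape):
    """Row-major strides (in elements) for a statically shaped buffer."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def linear_offset(indices, strides):
    """Element offset addressed by a list of per-dimension indices."""
    return sum(i * s for i, s in zip(indices, strides))


old_shape = (2, 2, 8, 8, 4, 4)
new_shape = (1, 1, 8, 8, 4, 4)
print(contiguous_strides(old_shape))  # [2048, 1024, 128, 16, 4, 1]
print(contiguous_strides(new_shape))  # [1024, 1024, 128, 16, 4, 1]

# Offsets carried over unchanged from the 2x2 buffer, e.g. %arg2 = %arg3 = 1.
stale = [1, 1, 0, 0, 0, 0]
# Assumed correction: the leading two offsets become 0 in the 1x1 buffer.
fixed = [0, 0, 0, 0, 0, 0]

new_strides = contiguous_strides(new_shape)
print(linear_offset(stale, new_strides), "of", prod(new_shape))  # 2048 of 1024
print(linear_offset(fixed, new_strides), "of", prod(new_shape))  # 0 of 1024
```

With the stale offsets the access lands past the end of the 1024-element buffer, which is consistent with parts of the output never being written.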

After all the fixes, the objectfifo + vectorization pipeline now works for the matmul example and generates correct results.

MaheshRavishankar commented 1 week ago

Great job folks!!! This puts us on a great path! Looking forward to changing all matmuls to go down this path

yzhang93 commented 1 week ago

@Abhishek-Varma and I will continue using this issue as a tracker for debugging and fixing issues for bf16 vectorization.

With nothing changed from the previous branch, it crashed at DistributeCoresAndObjectFifos. But when I looked at the IR, it was already not being generated correctly before that pass. In short, the current passes don't support the vector dialect and its operations.

The code within a core for the i32 type looks like this, for example:

%22 = amdaie.core(%tile) {
          amdaie.logicalobjectfifo.consume(%19)
          amdaie.logicalobjectfifo.consume(%20)
          linalg.fill ins(%c0_i32 : i32) outs(%subview : memref<1x1x8x8x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32>)
          scf.for %arg4 = %c0 to %c8 step %c1 {
            scf.for %arg5 = %c0 to %c8 step %c1 {
              scf.for %arg6 = %c0 to %c4 step %c1 {
                %subview_5 = memref.subview %alloc_0[0, 0, %arg6, %arg4, 0, 0] [1, 1, 1, 1, 4, 8] [1, 1, 1, 1, 1, 1] : memref<1x1x4x8x4x8xi32, 2 : i32> to memref<1x1x1x1x4x8xi32, strided<[1024, 1024, 256, 32, 8, 1], offset: ?>, 2 : i32>
                %subview_6 = memref.subview %alloc[0, 0, %arg5, %arg6, 0, 0] [1, 1, 1, 1, 8, 4] [1, 1, 1, 1, 1, 1] : memref<1x1x8x4x8x4xi32, 2 : i32> to memref<1x1x1x1x8x4xi32, strided<[1024, 1024, 128, 32, 4, 1], offset: ?>, 2 : i32>
                %subview_7 = memref.subview %alloc_3[%arg2, %arg3, %arg5, %arg4, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<2x2x8x8x4x4xi32, 2 : i32> to memref<1x1x1x1x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32>
                linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%subview_5, %subview_6 : memref<1x1x1x1x4x8xi32, strided<[1024, 1024, 256, 32, 8, 1], offset: ?>, 2 : i32>, memref<1x1x1x1x8x4xi32, strided<[1024, 1024, 128, 32, 4, 1], offset: ?>, 2 : i32>) outs(%subview_7 : memref<1x1x1x1x4x4xi32, strided<[2048, 1024, 128, 16, 4, 1], offset: ?>, 2 : i32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64], [0, 0, 1], [1, 1, 0, 0, 0, 0]]>, packing_config = #amdaie.packing_config<packing_config = [{packedSizes = [32, 32, 32], transposePackIndices = [1], unpackEmpty = [false], innerPerm = [[1, 0]], outerPerm = [[0, 1]]}, {packedSizes = [0, 0, 0, 4, 4, 8], transposePackIndices = [0, 1, 2], unpackEmpty = [false, false, true], innerPerm = [[0, 1], [1, 0], [0, 1]], outerPerm = [[0, 1, 3, 2], [0, 1, 3, 2], [0, 1, 3, 2]]}]>} {
                ^bb0(%in: i32, %in_8: i32, %out: i32):
                  %23 = arith.muli %in, %in_8 : i32
                  %24 = arith.addi %out, %23 : i32
                  linalg.yield %24 : i32
                }
              }
            }
          }
          amdaie.end
        }

For bf16 it now becomes:

%22 = amdaie.core(%tile) {
          amdaie.logicalobjectfifo.consume(%19)
          amdaie.logicalobjectfifo.consume(%20)
          linalg.fill ins(%cst : bf16) outs(%subview : memref<1x1x16x16x4x4xbf16, strided<[8192, 4096, 256, 16, 4, 1], offset: ?>, 2 : i32>)
          scf.for %arg4 = %c0 to %c16 step %c1 {
            scf.for %arg5 = %c0 to %c16 step %c1 {
              scf.for %arg6 = %c0 to %c8 step %c1 {
                %23 = vector.transfer_read %alloc_0[%c0, %c0, %arg6, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x16x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
                %24 = vector.transfer_read %alloc[%c0, %c0, %arg5, %arg6, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
                %25 = vector.transfer_read %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<2x2x16x16x4x4xbf16, 2 : i32>, vector<1x1x1x1x4x4xbf16>
                %26 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %23, %24, %25 : vector<1x1x1x1x4x8xbf16>, vector<1x1x1x1x8x4xbf16> into vector<1x1x1x1x4x4xbf16>
                vector.transfer_write %26, %alloc_3[%arg2, %arg3, %arg5, %arg4, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xbf16>, memref<2x2x16x16x4x4xbf16, 2 : i32>
              }
            }
          }
          amdaie.end
        }

I haven't finished adding support in all the places that deal with vector.contract, but @Abhishek-Varma could take my WIP commit as a starting point for adding the remaining support.

Abhishek-Varma commented 1 week ago

Fix commit : Fix distribute-core-and-objectfifo + initial fix to lower-to-aie.

This is on the same branch : avarma_matmul_vectorization_objectfifo.

Current state : e2e IR log

Fixes added :-

  1. vector.transfer_write had to be dealt with while distributing local memory - distribute-core-and-objectfifo.
  2. Added support for vector ops when inserting logical objectfifos - this entailed ensuring that the vector.transfer_read and vector.transfer_write from the SAME objectfifo are indeed using the same logicalobjectfifo.access op - distribute-core-and-objectfifo.
  3. WIP initial support for dealing with bf16 memref types at func.func input arguments - lower-to-aie.

CC: @MaheshRavishankar @jtuyls @yzhang93

Abhishek-Varma commented 1 week ago

Current IR state : new e2e IR state

Previous IR state : old e2e IR state

The difference between the two is the func.func (....) part :-

OLD IR
func.func @matmul_i8_i32_dispatch_0_matmul_128x128x256_bf16(%arg0: memref<128x256xi32>, 
                                                            %arg1: memref<256x128xi32>,
                                                            %arg2: memref<128x128xi32>) {
            aiex.npu.dma_memcpy_nd(0, 0, %arg0[1, 0, 0, 0][1, 2, 64, 64][1, 16384, 256]) {id = 0 : i64, issue_token = true, metadata = @obj0} : memref<128x256xi32>
            aiex.npu.dma_wait {symbol = @obj0}

NEW IR
func.func @matmul_i8_i32_dispatch_0_matmul_128x128x256_bf16(%arg0: memref<128x128xi32>, 
                                                            %arg1: memref<256x64xi32>, 
                                                            %arg2: memref<128x64xi32>) {
      aiex.npu.dma_memcpy_nd(0, 0, %arg0[1, 0, 0, 0][1, 2, 64, 32][1, 8192, 128]) {id = 0 : i64, issue_token = true, metadata = @obj0} : memref<128x128xi32>
      aiex.npu.dma_wait {symbol = @obj0}

Changes done within lower-to-aie :-

  1. The last dimension of the LHS, RHS and OUTPUT args was halved based on the element type bit width.
  2. The strides in aiex.npu.dma_memcpy_nd were halved if not 1 - this should be done ONLY if the element type is bf16 (but currently I'm enforcing it anyway - something to handle later when the work is PR-ready).
  3. The offsets/sizes whose corresponding stride is 1 were halved as well. Note: the last stride is implicitly 1, so the strides vector has size 3 instead of 4 like the offsets/sizes.
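
The three changes above can be sketched in Python as follows. This is a hedged illustration, not the code on the branch: the function names are hypothetical, the ratio of 2 assumes 16-bit bf16 elements packed into 32-bit words, and the sketch halves only the innermost offset/size (the dimension with the implicit stride of 1). That last choice is an assumption made because it reproduces the NEW IR shown above; how explicit stride-1 entries are treated on the branch is not visible in this thread.

```python
def halve_last_dim(shape, ratio=2):
    """Halve the innermost dimension of a memref shape (bf16 viewed as i32)."""
    if shape[-1] % ratio != 0:
        raise ValueError("innermost dimension must be divisible by the ratio")
    return shape[:-1] + (shape[-1] // ratio,)


def halve_dma_pattern(offsets, sizes, strides, ratio=2):
    """Rescale a 4-D DMA access pattern from bf16 elements to i32 words.

    offsets/sizes have 4 entries; strides has 3 because the innermost
    stride is implicitly 1.
    """
    new_strides = [s if s == 1 else s // ratio for s in strides]
    new_offsets = list(offsets)
    new_sizes = list(sizes)
    # Assumption: only the innermost (implicit stride-1) entries are halved.
    new_offsets[-1] //= ratio
    new_sizes[-1] //= ratio
    return new_offsets, new_sizes, new_strides


# func.func argument shapes from the OLD IR above.
print(halve_last_dim((128, 256)))  # (128, 128)
print(halve_last_dim((256, 128)))  # (256, 64)
print(halve_last_dim((128, 128)))  # (128, 64)

# Access pattern of the first dma_memcpy_nd in the OLD IR.
print(halve_dma_pattern([1, 0, 0, 0], [1, 2, 64, 64], [1, 16384, 256]))
# ([1, 0, 0, 0], [1, 2, 64, 32], [1, 8192, 128])
```

The printed values agree with the argument types and the first dma_memcpy_nd of the NEW IR above.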

Changes are in the same branch : avarma_matmul_vectorization_objectfifo

I tried going through createEmulateNarrowTypes, but that seems to work through the HAL binding subspan ops, which lower-to-aie erases after binding the inputs/objectfifos.

CC: @MaheshRavishankar @jtuyls @yzhang93

Abhishek-Varma commented 1 week ago

Hi.

  1. I first tried looking into the emulate-narrow-types pass.
     a. As @yzhang93 rightly mentioned, it works only for integers, so I added a few changes to deal with floats while upcasting to integer.
     b. The pass basically changes ALL memrefs/vectors to the intended upcast data type - but that's NOT what we want (please correct me if I'm wrong).
     c. I then switched on just the logic for the HAL binding subspan op - it does create a linearized i32 version of the respective bf16 types, but it didn't work because there's NO way (that I'm aware of) to replace <axbxbf16> with <dxi32>. Based on that, point 1.b above makes sense.

  2. I thought of looking into the NON-objectfifo bf16 IR instead - because clearly we aren't dealing with bf16 in AIE codegen for the first time.
     a. Here is the bf16 IR for the case which works (non-objectfifo) : e2e IR for non-objectfifo bf16 matmul.
     b. Based on that I looked into the AIRRtToNpu pass and found how they're adding the upcast after linearizing (CC: @MaheshRavishankar ).
     c. I used that logic within lower-to-aie and was able to get the current e2e IR log.

  3. Since the func.func change we want looked okay to me (2.c), I tried running that to generate the .vmfb (I wanted to see whether it works at all) - it failed at the LowerVectorToAIEVec pass : e2e IR log during .vmfb generation

  4. I can try looking into LowerVectorToAIEVec but I think I might just be chasing a red herring here.

As per my understanding of 1.b - ALL the structure that's NOT control code in the IR will remain the SAME - we just need to update the func.func and the aiex.npu.dma_memcpy_nd ops, which make up the control code - and this is what 2.b does.

CC: @MaheshRavishankar @jtuyls @yzhang93

jtuyls commented 1 week ago
  • Since the func.func change we want looked okay to me (2.c) I tried running that to generate .vmfb (wanted to see if at all it works) - it failed at LowerVectorToAIEVec pass : e2e IR log during .vmfb generation
  • I can try looking into LowerVectorToAIEVec but I think I might just be chasing a red-herring here.

It seems to crash on converting the vector.contract op to an aievec::matmul op. Could you try using an f32 output instead of bf16, as bf16 output is not supported in the aievec::matmul op: https://github.com/Xilinx/mlir-aie/blob/d850560c77799af96c6361a79e34cb0a8e842c50/include/aie/Dialect/AIEVec/IR/AIEVecOps.td#L869

Abhishek-Varma commented 1 week ago
I tried bf16 input and f32 output/accumulator as the input dispatch.

func.func @matmul_bf16_f32(%lhs: tensor<128x256xbf16>, %rhs: tensor<256x128xbf16>) -> tensor<128x128xf32>
{
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty() : tensor<128x128xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
  %res = linalg.matmul ins(%lhs, %rhs: tensor<128x256xbf16>, tensor<256x128xbf16>)
                    outs(%1: tensor<128x128xf32>) -> tensor<128x128xf32>
  return %res : tensor<128x128xf32>
}

It seems to have yet another issue because this is the loop nest after vectorization :-

scf.for %arg5 = %c0 to %c16 step %c1 {
  scf.for %arg6 = %c0 to %c16 step %c1 {
    scf.for %arg7 = %c0 to %c8 step %c1 {
      %27 = vector.transfer_read %alloc_1[%c0, %c0, %arg7, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x16x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
      %28 = vector.transfer_read %alloc[%c0, %c0, %arg6, %arg7, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
      %29 = vector.transfer_read %alloc_4[%arg3, %arg4, %arg6, %arg5, %c0, %c0], %cst_0 {in_bounds = [true, true, true, true, true, true]} : memref<2x2x16x16x4x4xf32, 2 : i32>, vector<1x1x1x1x4x4xf32>
      %30 = arith.extf %27 : vector<1x1x1x1x4x8xbf16> to vector<1x1x1x1x4x8xf32>
      %31 = arith.extf %28 : vector<1x1x1x1x8x4xbf16> to vector<1x1x1x1x8x4xf32>
      %32 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %30, %31, %29 : vector<1x1x1x1x4x8xf32>, vector<1x1x1x1x8x4xf32> into vector<1x1x1x1x4x4xf32>
      vector.transfer_write %32, %alloc_4[%arg3, %arg4, %arg6, %arg5, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xf32>, memref<2x2x16x16x4x4xf32, 2 : i32>
    }
  }
}

So now we apparently have f32 inputs and f32 outputs. It crashed in the distribute-core-and-objectfifo pass (and I understand why), but this would cause an issue again lower in the stack at the LowerVectorToAIEVec pass, since AIEVecOps.td does not seem to support this configuration either.

On further look I see that this is not because of vectorization pass, because the input linalg.generic to the vectorization part is :-

%19 = linalg.generic  ins(%pack_15, %pack_17 : tensor<1x1x8x16x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>)
                      outs(%18 : tensor<1x1x16x16x4x4xf32>) {
      ^bb0(%in: bf16, %in_19: bf16, %out: f32):
        %20 = arith.extf %in : bf16 to f32
        %21 = arith.extf %in_19 : bf16 to f32
        %22 = arith.mulf %20, %21 : f32
        %23 = arith.addf %out, %22 : f32
        linalg.yield %23 : f32
      } -> tensor<1x1x16x16x4x4xf32>

So the body of the linalg.generic contains those bf16->f32 conversions of the inputs.
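For reference, a small Python sketch of what that region computes per element. It is only a model of the arithmetic, not compiler code; the function names are made up for illustration, and since Python floats are f64 the rounding of the product and sum to f32 is not modelled.

```python
import struct


def extf_bf16_to_f32(bits):
    """Widen a bf16 bit pattern (int in [0, 0xFFFF]) to the f32 value it denotes.

    bf16 is the upper 16 bits of an IEEE-754 f32, so this widening is exact.
    """
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def generic_body(in_bits, in_19_bits, out):
    """Model of the region: two extf ops, then mulf and addf on the f32 accumulator."""
    lhs = extf_bf16_to_f32(in_bits)
    rhs = extf_bf16_to_f32(in_19_bits)
    return out + lhs * rhs


# 0x3FC0 is bf16 1.5 and 0x4000 is bf16 2.0.
print(generic_body(0x3FC0, 0x4000, 0.5))  # 3.5
```

The point is that the inputs stay bf16 in memory; only the values fed to the multiply-accumulate are widened.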

jtuyls commented 1 week ago

Ouch, any ideas on getting rid of that upcast?

Abhishek-Varma commented 1 week ago

Another temporary pass for starters. :-P But we shouldn't do that.

There are e2e tests in CI that use bf16 inputs and f32 output for non-objectfifo pack-peel (with vectorization switched on by default). I'll need to inspect that e2e IR before I can comment further here.

But just so that we aren't chasing any red-herrings here, I take it that the func.func change in this e2e IR (I shared in the above thread) looks okay now ?

It does linearization + the adjustments to the bf16 (2 bytes) -> i32 (4 bytes) and is linking to the corresponding bf16 objectfifos too.

jtuyls commented 1 week ago

But just so that we aren't chasing any red-herrings here, I take it that the func.func change in this e2e IR (I shared in the above thread) looks okay now ?

The offsets still look weird:

aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 1][1, 2, 64, 64][1, 16384, 256]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj0}
aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 1][1, 2, 64, 64][1, 64, 128]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj1}
aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 65][1, 2, 64, 64][1, 16384, 256]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj0}
aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 8193][1, 2, 64, 64][1, 64, 128]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj1}
aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 129][1, 2, 64, 64][1, 16384, 256]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj0}
aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 16385][1, 2, 64, 64][1, 64, 128]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj1}
aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 193][1, 2, 64, 64][1, 16384, 256]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj0}
aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 24577][1, 2, 64, 64][1, 64, 128]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
aiex.npu.dma_wait {symbol = @obj1}
aiex.npu.dma_memcpy_nd(0, 0, %arg2[0, 0, 0, 2][1, 1, 128, 128][1, 1, 128]) {id = 0 : i64, metadata = @obj10} : memref<8192xi32>

How do we get odd values (1/65/8193 etc.)? Shouldn't it just be the previous offset / 2?
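A minimal sketch of the expected arithmetic, assuming offsets are first computed in bf16 elements and then rescaled to i32 words. The function name is hypothetical and this is not the lower-to-aie implementation; the bf16 offsets used below are assumed values for illustration.

```python
def scale_offset(offset_elems, elem_bits, word_bits=32):
    """Convert an offset counted in narrow elements into an offset in i32 words.

    Hypothetical helper: assumes the offset lands on a word boundary.
    """
    elems_per_word = word_bits // elem_bits
    if offset_elems % elems_per_word != 0:
        raise ValueError("offset is not aligned to a 32-bit word")
    return offset_elems // elems_per_word


# Assumed bf16 offsets; two bf16 elements fit in one i32 word.
bf16_offsets = [0, 128, 256, 384]
print([scale_offset(o, 16) for o in bf16_offsets])  # [0, 64, 128, 192]
```

With a plain division like this, an aligned bf16 offset of 0 maps to 0, so values such as 1, 65, or 8193 point to an extra +1 being added somewhere.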

yzhang93 commented 1 week ago

It seems to have yet another issue because this is the loop nest after vectorization :-

scf.for %arg5 = %c0 to %c16 step %c1 {
  scf.for %arg6 = %c0 to %c16 step %c1 {
    scf.for %arg7 = %c0 to %c8 step %c1 {
      %27 = vector.transfer_read %alloc_1[%c0, %c0, %arg7, %arg5, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x8x16x4x8xbf16, 2 : i32>, vector<1x1x1x1x4x8xbf16>
      %28 = vector.transfer_read %alloc[%c0, %c0, %arg6, %arg7, %c0, %c0], %cst {in_bounds = [true, true, true, true, true, true]} : memref<1x1x16x8x8x4xbf16, 2 : i32>, vector<1x1x1x1x8x4xbf16>
      %29 = vector.transfer_read %alloc_4[%arg3, %arg4, %arg6, %arg5, %c0, %c0], %cst_0 {in_bounds = [true, true, true, true, true, true]} : memref<2x2x16x16x4x4xf32, 2 : i32>, vector<1x1x1x1x4x4xf32>
      %30 = arith.extf %27 : vector<1x1x1x1x4x8xbf16> to vector<1x1x1x1x4x8xf32>
      %31 = arith.extf %28 : vector<1x1x1x1x8x4xbf16> to vector<1x1x1x1x8x4xf32>
      %32 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %30, %31, %29 : vector<1x1x1x1x4x8xf32>, vector<1x1x1x1x8x4xf32> into vector<1x1x1x1x4x4xf32>
      vector.transfer_write %32, %alloc_4[%arg3, %arg4, %arg6, %arg5, %c0, %c0] {in_bounds = [true, true, true, true, true, true]} : vector<1x1x1x1x4x4xf32>, memref<2x2x16x16x4x4xf32, 2 : i32>
    }
  }
}

This is not an issue, and AIEVec can handle it correctly. It combines arith.extf and vector.contract ops and generates aievec.matmul op with bf16 inputs and f32 output.

You can refer to this test https://github.com/Xilinx/mlir-aie/blob/54efffaa12dd4f0cb3cebbb7dbfa51bf78dc74f8/test/Conversion/VectorToAIEVec/test-contract.mlir#L106

Abhishek-Varma commented 1 week ago

Hi.

  1. Added a fix for the arith.extf op, needed because of the new accumulator type (distribute-core-and-objectfifo).
  2. After discussing with @jtuyls I've added a fix for the offsets in the current revision (lower-to-aie).
  3. Yesterday's func.func didn't have the size/stride metadata adjusted either, so I added that too (lower-to-aie).

Here's the current func.func :-

func.func @matmul_i8_i32_dispatch_0_matmul_128x128x256_bf16xbf16xf32(%arg0: memref<16384xi32>, %arg1: memref<16384xi32>, %arg2: memref<16384xi32>) {
  aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 0][0, 2, 64, 32][1, 8192, 128]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj0}
  aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 0][0, 2, 64, 32][1, 32, 64]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj1}
  aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 64][0, 2, 64, 32][1, 8192, 128]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj0}
  aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 8192][0, 2, 64, 32][1, 32, 64]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj1}
  aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 128][0, 2, 64, 32][1, 8192, 128]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj0}
  aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 16384][0, 2, 64, 32][1, 32, 64]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj1}
  aiex.npu.dma_memcpy_nd(0, 0, %arg0[0, 0, 0, 192][0, 2, 64, 32][1, 8192, 128]) {id = 0 : i64, metadata = @obj0} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj0}
  aiex.npu.dma_memcpy_nd(0, 0, %arg1[0, 0, 0, 24576][0, 2, 64, 32][1, 32, 64]) {id = 0 : i64, metadata = @obj1} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj1}
  aiex.npu.dma_memcpy_nd(0, 0, %arg2[0, 0, 0, 0][0, 0, 128, 64][1, 1, 64]) {id = 0 : i64, metadata = @obj10} : memref<16384xi32>
  aiex.npu.dma_wait {symbol = @obj10}
  return
}

NOTE: One more thing to observe in the above snippet is the OUTPUT. Since we're using an f32 accumulator now, the linearizing logic bailed out because the bitwidth is 32, and it was creating a func.func with (<16384xi32>, <16384xi32>, <128x128xf32>), which seemed wrong to me. I therefore enforced linearization even when the type is 32-bit and is NOT an IntegerType, which gives the above snippet.

On trying to generate a .vmfb from the above I got the following in AIEAssignBufferAddresses :-

'aie.tile' op allocated buffers exceeded available memory

Here is the IR log : e2e IR log
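As a rough sanity check on the memory error, the sketch below adds up the byte sizes of the three memory space 2 buffers from the vectorized loop nest shown earlier in this thread. Whether these are exactly the buffers in the failing IR log is an assumption, and so is the 64 KiB per-tile capacity, which should be checked against the target device.

```python
from math import prod

BYTES_PER_ELEM = {"bf16": 2, "f32": 4}

# Shapes copied from the vectorized loop nest earlier in the thread.
buffers = {
    "alloc_1": ((1, 1, 8, 16, 4, 8), "bf16"),
    "alloc": ((1, 1, 16, 8, 8, 4), "bf16"),
    "alloc_4": ((2, 2, 16, 16, 4, 4), "f32"),
}

# Assumption for illustration: 64 KiB of local memory per tile.
ASSUMED_TILE_BYTES = 64 * 1024


def buffer_bytes(shape, elem_type):
    """Byte size of a densely packed buffer."""
    return prod(shape) * BYTES_PER_ELEM[elem_type]


total = sum(buffer_bytes(shape, elem) for shape, elem in buffers.values())
for name, (shape, elem) in buffers.items():
    print(name, buffer_bytes(shape, elem))
print("total", total, "exceeds assumed capacity:", total > ASSUMED_TILE_BYTES)
```

Under these assumptions the f32 accumulator alone takes 65536 bytes, which would explain why shrinking the level 0 tiling/packing sizes is the natural next step.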

I adjusted the level 0 tiling/packing sizes to <tile_sizes = [[64, 64], .....> and packedSizes = [32, 32, 32] ..., but this failed in AIEObjectFifoStatefulTransform (CC: @jtuyls ). Here is the IR log of the updated tiling/packing level 0 : e2e IR new tiling/packing

Changes currently in the branch : avarma_test_emulator

CC: @MaheshRavishankar @jtuyls @yzhang93

jtuyls commented 3 days ago

@Abhishek-Varma I tried to replicate this error by running the amdaie-objectFifo-stateful-transform pass on the IR right before it, and the pass succeeded for me. Could you try running this as well with the latest iree-amd-aie?

Command:

${IREE_BUILD_DIR}/tools/iree-opt matmul_vec.mlir --mlir-print-ir-before-all --amdaie-objectFifo-stateful-transform

with matmul_vec.mlir containing the IR dump right before AIEObjectFifoStatefulTransform in the snippet you shared: https://gist.githubusercontent.com/Abhishek-Varma/90fbe66ec4aabb5a3da410885615c5f3/raw/184b9eac151d4c20a54a4a73d30904f29c6fa597/input.mlir

Gist: https://gist.github.com/jtuyls/0d46284d9d3c5780bd298ad7de4d88a3