onnx / onnx-mlir

Representation and Reference Lowering of ONNX Models in MLIR Compiler Infrastructure
Apache License 2.0

Unroll and Jam inside KrnlToAffine #770

Closed AlexandreEichenberger closed 3 years ago

AlexandreEichenberger commented 3 years ago

I created a custom branch (#769) to evaluate what is happening in unroll and jam.

The original ONNX program

func @test_gemm_16_16_16(%a: tensor<16x16xf32>, %b: tensor<16x16xf32>, %c: tensor<16xf32>) 
    -> tensor<16x16xf32> {
  %0 = "onnx.Gemm"(%a, %b, %c) : (tensor<16x16xf32>, tensor<16x16xf32>, tensor<16xf32>) 
    -> tensor<16x16xf32>
  "std.return"(%0) : (tensor<16x16xf32>) -> ()
}

can be lowered to affine with

onnx-mlir-opt onnx-gemm-16-16-16.mlir --convert-onnx-to-krnl --convert-krnl-to-affine

and when #define UNROLL_IT is set to 1, I get the crash below.

PLEASE submit a bug report to https://bugs.llvm.org/ and include the crash backtrace.
Stack dump:
0.  Program arguments: onnx-mlir-opt onnx-gemm-16-16-16.mlir --convert-onnx-to-krnl --convert-krnl-to-affine
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  onnx-mlir-opt            0x000000010704dffb llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 43
1  onnx-mlir-opt            0x000000010704cc78 llvm::sys::RunSignalHandlers() + 248
2  onnx-mlir-opt            0x000000010704e9f7 SignalHandler(int) + 295
3  libsystem_platform.dylib 0x00007fff204eed7d _sigtramp + 29
4  libsystem_platform.dylib 0x00007ffcbcd0d320 _sigtramp + 18446744063450408384
5  onnx-mlir-opt            0x0000000106f5815a mlir::OperationName::getDialectNamespace() const + 74
6  onnx-mlir-opt            0x000000010680cc32 mlir::ConversionTarget::getOpInfo(mlir::OperationName) const + 162
7  onnx-mlir-opt            0x000000010680d028 mlir::ConversionTarget::isLegal(mlir::Operation*) const + 56
8  onnx-mlir-opt            0x0000000106815e7b (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) + 1211
9  onnx-mlir-opt            0x00000001068184c1 mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>::callback_fn<(anonymous namespace)::OperationLegalizer::legalizeWithPattern(mlir::Operation*, mlir::ConversionPatternRewriter&)::$_14>(long, mlir::Pattern const&) + 529
10 onnx-mlir-opt            0x000000010686f440 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) + 1776
11 onnx-mlir-opt            0x000000010681699c (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) + 4060
12 onnx-mlir-opt            0x000000010680f4bc (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) + 1580
13 onnx-mlir-opt            0x0000000106811369 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget&, mlir::FrozenRewritePatternSet const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*> >*) + 73
14 onnx-mlir-opt            0x00000001063cc010 (anonymous namespace)::ConvertKrnlToAffinePass::runOnFunction() + 1776
15 onnx-mlir-opt            0x00000001069dc5c7 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 519
16 onnx-mlir-opt            0x00000001069dcb86 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 134
17 onnx-mlir-opt            0x00000001069e135d auto mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8::operator()<std::__1::pair<mlir::Operation*, mlir::AnalysisManager> >(std::__1::pair<mlir::Operation*, mlir::AnalysisManager>&) const + 381
18 onnx-mlir-opt            0x00000001069e0d1b mlir::LogicalResult mlir::failableParallelForEach<std::__1::__wrap_iter<std::__1::pair<mlir::Operation*, mlir::AnalysisManager>*>, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8&>(mlir::MLIRContext*, std::__1::__wrap_iter<std::__1::pair<mlir::Operation*, mlir::AnalysisManager>*>, std::__1::__wrap_iter<std::__1::pair<mlir::Operation*, mlir::AnalysisManager>*>, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_8&) + 139
19 onnx-mlir-opt            0x00000001069dd91d mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 1341
20 onnx-mlir-opt            0x00000001069dc79c mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 988
21 onnx-mlir-opt            0x00000001069dcb86 mlir::detail::OpToOpPassAdaptor::runPipeline(llvm::iterator_range<llvm::pointee_iterator<std::__1::unique_ptr<mlir::Pass, std::__1::default_delete<mlir::Pass> >*, mlir::Pass> >, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 134
22 onnx-mlir-opt            0x00000001069de4f1 mlir::PassManager::run(mlir::Operation*) + 689
23 onnx-mlir-opt            0x00000001065c76bd performActions(llvm::raw_ostream&, bool, bool, llvm::SourceMgr&, mlir::MLIRContext*, mlir::PassPipelineCLParser const&) + 525
24 onnx-mlir-opt            0x00000001065c586e processBuffer(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, bool, bool, bool, bool, mlir::PassPipelineCLParser const&, mlir::DialectRegistry&) + 622
25 onnx-mlir-opt            0x00000001065c55d4 mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer> >, mlir::PassPipelineCLParser const&, mlir::DialectRegistry&, bool, bool, bool, bool, bool) + 180
26 onnx-mlir-opt            0x0000000105d57d2c main + 1468
27 libdyld.dylib            0x00007fff204c4f5d start + 1

I added code to dump the function just before returning success from the pattern that lowers krnl.matmul (a minimal sketch of such a hook follows this paragraph). The dumped function is shown after that; in it, I simply commented out the krnl.matmul operation that was replaced by the generated code. It is a mixture of affine and krnl.load/krnl.store operations.
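As a rough illustration of the dump hook (a sketch only, assuming the MLIR C++ API of that era with FuncOp still in the builtin dialect; the helper name dumpEnclosingFunction is made up and is not onnx-mlir code):

#include "mlir/IR/BuiltinOps.h"

// Hypothetical debug helper: print the whole enclosing function right
// before the krnl.matmul lowering pattern returns success, so the
// partially lowered IR (affine plus krnl.load/store) can be inspected.
static void dumpEnclosingFunction(mlir::Operation *matmulOp) {
  if (auto funcOp = matmulOp->getParentOfType<mlir::FuncOp>())
    funcOp->dump(); // writes the textual IR to stderr
}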

func @test_gemm_16_16_16(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16xf32>) -> memref<16x16xf32> {
  %c0 = constant 0 : index
  %0 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
  %1 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
  %2 = memref.alloc() {alignment = 128 : i64} : memref<16x16xf32>
  %c16 = constant 16 : index
  %c16_0 = constant 16 : index
  %c16_1 = constant 16 : index
  %c16_2 = constant 16 : index
  %c1 = constant 1 : index
  %c16_3 = constant 16 : index
  %cst = constant 1.000000e+00 : f32
  %cst_4 = constant 1.000000e+00 : f32
  %cst_5 = constant 0.000000e+00 : f32
  %c0_6 = constant 0 : index
  affine.for %arg3 = 0 to 16 {
    affine.for %arg4 = 0 to 16 {
      affine.store %cst_5, %2[%arg3, %arg4] : memref<16x16xf32>
      krnl.store %cst_5, %2[%arg3, %arg4] : memref<16x16xf32>
      affine.yield
    }
  }
  affine.for %arg3 = 0 to 16 step 128 {
    affine.for %arg4 = 0 to 16 step 512 {
      %c0_10 = constant 0 : index
      %c16_11 = constant 16 : index
      %c512 = constant 512 : index
      %3 = affine.apply affine_map<(d0) -> (-d0 + 16)>(%arg4)
      %4 = affine.min affine_map<(d0) -> (-d0 + 16, 512)>(%arg4)
      %c1_12 = constant 1 : index
      %c16_13 = constant 16 : index
      %c128 = constant 128 : index
      %5 = affine.apply affine_map<(d0, d1) -> (-d1 + 16)>(%arg4, %arg3)
      %6 = affine.min affine_map<(d0, d1) -> (-d1 + 16, 128)>(%arg4, %arg3)
      %c1_14 = constant 1 : index
      affine.for %arg5 = 0 to min affine_map<(d0) -> (-d0 + 16, 512)>(%arg4) {
        affine.for %arg6 = 0 to min affine_map<(d0, d1) -> (-d1 + 16, 128)>(%arg4, %arg3) {
          %7 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg5, %arg4)
          %8 = affine.apply affine_map<(d0, d1, d2, d3) -> (d2 + d3)>(%arg5, %arg4, %arg6, %arg3)
          %9 = affine.load %arg1[%7, %8] : memref<16x16xf32>
          %10 = krnl.load %arg1[%7, %8] : memref<16x16xf32>
          affine.store %10, %0[%arg5, %arg6] : memref<512x128xf32>
          krnl.store %10, %0[%arg5, %arg6] : memref<512x128xf32>
        }
      }
      krnl.copy_to_tile_buffer %0, %arg1[%arg4, %arg3], %cst_5 {padToNext = [], tileSize = [], transpose = false} : memref<512x128xf32>, memref<16x16xf32>
      affine.for %arg5 = 0 to 16 step 64 {
        %c0_15 = constant 0 : index
        %c16_16 = constant 16 : index
        %c64 = constant 64 : index
        %7 = affine.apply affine_map<(d0) -> (-d0 + 16)>(%arg5)
        %8 = affine.min affine_map<(d0) -> (-d0 + 16, 64)>(%arg5)
        %c1_17 = constant 1 : index
        %c16_18 = constant 16 : index
        %c512_19 = constant 512 : index
        %9 = affine.apply affine_map<(d0, d1) -> (-d1 + 16)>(%arg5, %arg4)
        %10 = affine.min affine_map<(d0, d1) -> (-d1 + 16, 512)>(%arg5, %arg4)
        %c1_20 = constant 1 : index
        affine.for %arg6 = 0 to min affine_map<(d0) -> (-d0 + 16, 64)>(%arg5) {
          affine.for %arg7 = 0 to min affine_map<(d0, d1) -> (-d1 + 16, 512)>(%arg5, %arg4) {
            %11 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg6, %arg5)
            %12 = affine.apply affine_map<(d0, d1, d2, d3) -> (d2 + d3)>(%arg6, %arg5, %arg7, %arg4)
            %13 = affine.load %arg0[%11, %12] : memref<16x16xf32>
            %14 = krnl.load %arg0[%11, %12] : memref<16x16xf32>
            affine.store %14, %1[%arg6, %arg7] : memref<64x512xf32>
            krnl.store %14, %1[%arg6, %arg7] : memref<64x512xf32>
          }
        }
        krnl.copy_to_tile_buffer %1, %arg0[%arg5, %arg4], %cst_5 {padToNext = [], tileSize = [], transpose = false} : memref<64x512xf32>, memref<16x16xf32>
        affine.for %arg6 = affine_map<(d0) -> (d0)>(%arg3) to affine_map<(d0) -> (d0 + 16)>(%arg3) step 8 {
          affine.for %arg7 = affine_map<(d0) -> (d0)>(%arg5) to affine_map<(d0) -> (d0 + 16)>(%arg5) step 4 {
            %c64_21 = constant 64 : index
            %c512_22 = constant 512 : index
            %c512_23 = constant 512 : index
            %c128_24 = constant 128 : index
            %c16_25 = constant 16 : index
            %c16_26 = constant 16 : index
            %c4 = constant 4 : index
            %c8 = constant 8 : index
            %c512_27 = constant 512 : index
            %c16_28 = constant 16 : index
            %c16_29 = constant 16 : index
            %c16_30 = constant 16 : index
            %11 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%arg7, %arg5)
            %12 = affine.apply affine_map<(d0, d1, d2, d3) -> (d2 - d3)>(%arg7, %arg5, %arg4, %arg4)
            %13 = affine.apply affine_map<(d0, d1, d2, d3, d4) -> (d2 - d4)>(%arg7, %arg5, %arg4, %arg4, %arg4)
            %14 = affine.apply affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5 - d6)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %c0_31 = constant 0 : index
            %15 = affine.apply affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %c0_32 = constant 0 : index
            %16 = affine.apply affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %17 = affine.apply affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ((d5 - d6) floordiv 8)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %18 = affine.apply affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5 floordiv 8)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %c1_33 = constant 1 : index
            %c1_34 = constant 1 : index
            %c-496 = constant -496 : index
            %19 = affine.apply affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (-d2 - 496)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %20 = affine.apply affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (-d2 + 16)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %21 = affine.min affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (-d2 + 16, 512)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3)
            %c0_35 = constant 0 : index
            affine.if affine_set<(d0, d1, d2, d3, d4, d5, d6) : (1 >= 0, 1 >= 0, -d2 - 496 >= 0)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3) {
              %22 = krnl.vector_type_cast %0 : memref<512x128xf32> to memref<512x16xvector<8xf32>>
              %23 = krnl.vector_type_cast %2 : memref<16x16xf32> to memref<16x2xvector<8xf32>>
              %24 = memref.alloca() {alignment = 64 : i64} : memref<vector<8xf32>>
              %c0_36 = constant 0 : index
              %25 = addi %c0, %15 : index
              %26 = krnl.load %23[%25, %18] : memref<16x2xvector<8xf32>>
              krnl.store %26, %24[] : memref<vector<8xf32>>
              %27 = affine.apply affine_map<(d0) -> (d0 + 1)>(%c0)
              %28 = addi %27, %15 : index
              %29 = krnl.load %23[%28, %18] : memref<16x2xvector<8xf32>>
              krnl.store %29, %24[] : memref<vector<8xf32>>
              %30 = affine.apply affine_map<(d0) -> (d0 + 2)>(%c0)
              %31 = addi %30, %15 : index
              %32 = krnl.load %23[%31, %18] : memref<16x2xvector<8xf32>>
              krnl.store %32, %24[] : memref<vector<8xf32>>
              %33 = affine.apply affine_map<(d0) -> (d0 + 3)>(%c0)
              %34 = addi %33, %15 : index
              %35 = krnl.load %23[%34, %18] : memref<16x2xvector<8xf32>>
              krnl.store %35, %24[] : memref<vector<8xf32>>
              affine.for %arg8 = 0 to 512 {
                %43 = addi %c0, %11 : index
                %44 = addi %arg8, %12 : index
                %45 = krnl.load %1[%43, %44] : memref<64x512xf32>
                %46 = vector.broadcast %45 : f32 to vector<8xf32>
                %47 = addi %arg8, %13 : index
                %48 = krnl.load %22[%47, %17] : memref<512x16xvector<8xf32>>
                %49 = krnl.load %24[] : memref<vector<8xf32>>
                %50 = vector.fma %46, %48, %49 : vector<8xf32>
                krnl.store %50, %24[] : memref<vector<8xf32>>
                %51 = affine.apply affine_map<(d0) -> (d0 + 1)>(%c0)
                %52 = addi %51, %11 : index
                %53 = addi %arg8, %12 : index
                %54 = krnl.load %1[%52, %53] : memref<64x512xf32>
                %55 = vector.broadcast %54 : f32 to vector<8xf32>
                %56 = addi %arg8, %13 : index
                %57 = krnl.load %22[%56, %17] : memref<512x16xvector<8xf32>>
                %58 = krnl.load %24[] : memref<vector<8xf32>>
                %59 = vector.fma %55, %57, %58 : vector<8xf32>
                krnl.store %59, %24[] : memref<vector<8xf32>>
                %60 = affine.apply affine_map<(d0) -> (d0 + 2)>(%c0)
                %61 = addi %60, %11 : index
                %62 = addi %arg8, %12 : index
                %63 = krnl.load %1[%61, %62] : memref<64x512xf32>
                %64 = vector.broadcast %63 : f32 to vector<8xf32>
                %65 = addi %arg8, %13 : index
                %66 = krnl.load %22[%65, %17] : memref<512x16xvector<8xf32>>
                %67 = krnl.load %24[] : memref<vector<8xf32>>
                %68 = vector.fma %64, %66, %67 : vector<8xf32>
                krnl.store %68, %24[] : memref<vector<8xf32>>
                %69 = affine.apply affine_map<(d0) -> (d0 + 3)>(%c0)
                %70 = addi %69, %11 : index
                %71 = addi %arg8, %12 : index
                %72 = krnl.load %1[%70, %71] : memref<64x512xf32>
                %73 = vector.broadcast %72 : f32 to vector<8xf32>
                %74 = addi %arg8, %13 : index
                %75 = krnl.load %22[%74, %17] : memref<512x16xvector<8xf32>>
                %76 = krnl.load %24[] : memref<vector<8xf32>>
                %77 = vector.fma %73, %75, %76 : vector<8xf32>
                krnl.store %77, %24[] : memref<vector<8xf32>>
              }
              %36 = krnl.load %24[] : memref<vector<8xf32>>
              krnl.store %36, %23[%25, %18] : memref<16x2xvector<8xf32>>
              %37 = affine.apply affine_map<(d0) -> (d0 + 1)>(%c0)
              %38 = krnl.load %24[] : memref<vector<8xf32>>
              krnl.store %38, %23[%28, %18] : memref<16x2xvector<8xf32>>
              %39 = affine.apply affine_map<(d0) -> (d0 + 2)>(%c0)
              %40 = krnl.load %24[] : memref<vector<8xf32>>
              krnl.store %40, %23[%31, %18] : memref<16x2xvector<8xf32>>
              %41 = affine.apply affine_map<(d0) -> (d0 + 3)>(%c0)
              %42 = krnl.load %24[] : memref<vector<8xf32>>
              krnl.store %42, %23[%34, %18] : memref<16x2xvector<8xf32>>
            } else {
              affine.if affine_set<(d0, d1, d2, d3, d4, d5, d6) : (1 >= 0)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3) {
                %22 = krnl.vector_type_cast %0 : memref<512x128xf32> to memref<512x16xvector<8xf32>>
                %23 = krnl.vector_type_cast %2 : memref<16x16xf32> to memref<16x2xvector<8xf32>>
                %24 = memref.alloca() {alignment = 64 : i64} : memref<vector<8xf32>>
                %c0_36 = constant 0 : index
                affine.for %arg8 = 0 to 4 {
                  %25 = addi %arg8, %15 : index
                  %26 = krnl.load %23[%25, %18] : memref<16x2xvector<8xf32>>
                  krnl.store %26, %24[] : memref<vector<8xf32>>
                  affine.for %arg9 = 0 to min affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (-d2 + 16, 512)>(%arg7, %arg5, %arg4, %arg4, %arg4, %arg6, %arg3) {
                    %28 = addi %arg8, %11 : index
                    %29 = addi %arg9, %12 : index
                    %30 = krnl.load %1[%28, %29] : memref<64x512xf32>
                    %31 = vector.broadcast %30 : f32 to vector<8xf32>
                    %32 = addi %arg9, %13 : index
                    %33 = krnl.load %22[%32, %17] : memref<512x16xvector<8xf32>>
                    %34 = krnl.load %24[] : memref<vector<8xf32>>
                    %35 = vector.fma %31, %33, %34 : vector<8xf32>
                    krnl.store %35, %24[] : memref<vector<8xf32>>
                  }
                  %27 = krnl.load %24[] : memref<vector<8xf32>>
                  krnl.store %27, %23[%25, %18] : memref<16x2xvector<8xf32>>
                }
              } else {
              }
            }
            // krnl.matmul %1[%arg5, %arg4], %0[%arg4, %arg3], %2[%c0_6, %c0_6], (), (%arg7, %arg6, %arg4), (%c16, %c16_2, %c16_0) {aTileSize = [], bTileSize = [], cTileSize = [], computeTileSize = [4, 8, 512], overcompute = false, simdize = true, unroll = true} : memref<64x512xf32>, memref<512x128xf32>, memref<16x16xf32>, ()
          }
        }
      }
    }
  }
  %c16_7 = constant 16 : index
  %c1_8 = constant 1 : index
  %true = constant true
  %c0_9 = constant 0 : index
  affine.for %arg3 = 0 to 16 {
    affine.for %arg4 = 0 to 16 {
      %3 = krnl.load %2[%arg3, %arg4] : memref<16x16xf32>
      %4 = krnl.load %arg2[%arg4] : memref<16xf32>
      %5 = addf %3, %4 : f32
      krnl.store %5, %2[%arg3, %arg4] : memref<16x16xf32>
    }
  }
  memref.dealloc %1 : memref<64x512xf32>
  memref.dealloc %0 : memref<512x128xf32>
  return %2 : memref<16x16xf32>
}

At a cursory glance, the code looks OK. I fed it back to onnx-mlir-opt, asking it again to lower krnl to affine, and it produced the code below without any crashes.
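Concretely, the re-run is just the same driver invoked on the dumped IR (the input file name here is only a placeholder):

onnx-mlir-opt dumped-gemm-16-16-16.mlir --convert-krnl-to-affine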

#map0 = affine_map<(d0) -> (-d0 + 16, 512)>
#map1 = affine_map<(d0) -> (-d0 + 16, 128)>
#map2 = affine_map<(d0) -> (-d0 + 16, 64)>
#map3 = affine_map<(d0) -> (d0)>
#map4 = affine_map<(d0) -> (d0 + 16)>
#map5 = affine_map<(d0, d1) -> (d0 - d1)>
#map6 = affine_map<(d0, d1) -> ((d0 - d1) floordiv 8)>
#map7 = affine_map<(d0) -> (d0 floordiv 8)>
#set0 = affine_set<(d0) : (1 >= 0, 1 >= 0, -d0 - 496 >= 0)>
#set1 = affine_set<() : (1 >= 0)>
module  {
  func @test_gemm_16_16_16(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16xf32>) -> memref<16x16xf32> {
    %c3 = constant 3 : index
    %c2 = constant 2 : index
    %c1 = constant 1 : index
    %cst = constant 0.000000e+00 : f32
    %0 = memref.alloc() {alignment = 128 : i64} : memref<512x128xf32>
    %1 = memref.alloc() {alignment = 128 : i64} : memref<64x512xf32>
    %2 = memref.alloc() {alignment = 128 : i64} : memref<16x16xf32>
    affine.for %arg3 = 0 to 16 {
      affine.for %arg4 = 0 to 16 {
        affine.store %cst, %2[%arg3, %arg4] : memref<16x16xf32>
        affine.store %cst, %2[%arg3, %arg4] : memref<16x16xf32>
      }
    }
    affine.for %arg3 = 0 to 16 step 128 {
      affine.for %arg4 = 0 to 16 step 512 {
        affine.for %arg5 = 0 to min #map0(%arg4) {
          affine.for %arg6 = 0 to min #map1(%arg3) {
            %3 = affine.load %arg1[%arg5 + %arg4, %arg6 + %arg3] : memref<16x16xf32>
            affine.store %3, %0[%arg5, %arg6] : memref<512x128xf32>
            affine.store %3, %0[%arg5, %arg6] : memref<512x128xf32>
          }
        }
        affine.for %arg5 = 0 to min #map0(%arg4) {
          affine.for %arg6 = 0 to min #map1(%arg3) {
            %3 = affine.load %arg1[%arg5 + %arg4, %arg6 + %arg3] : memref<16x16xf32>
            affine.store %3, %0[%arg5, %arg6] : memref<512x128xf32>
          }
        }
        affine.for %arg5 = 0 to 16 step 64 {
          affine.for %arg6 = 0 to min #map2(%arg5) {
            affine.for %arg7 = 0 to min #map0(%arg4) {
              %3 = affine.load %arg0[%arg6 + %arg5, %arg7 + %arg4] : memref<16x16xf32>
              affine.store %3, %1[%arg6, %arg7] : memref<64x512xf32>
              affine.store %3, %1[%arg6, %arg7] : memref<64x512xf32>
            }
          }
          affine.for %arg6 = 0 to min #map2(%arg5) {
            affine.for %arg7 = 0 to min #map0(%arg4) {
              %3 = affine.load %arg0[%arg6 + %arg5, %arg7 + %arg4] : memref<16x16xf32>
              affine.store %3, %1[%arg6, %arg7] : memref<64x512xf32>
            }
          }
          affine.for %arg6 = #map3(%arg3) to #map4(%arg3) step 8 {
            affine.for %arg7 = #map3(%arg5) to #map4(%arg5) step 4 {
              %3 = affine.apply #map5(%arg7, %arg5)
              %4 = affine.apply #map6(%arg6, %arg3)
              %5 = affine.apply #map7(%arg6)
              affine.if #set0(%arg4) {
                %6 = krnl.vector_type_cast %0 : memref<512x128xf32> to memref<512x16xvector<8xf32>>
                %7 = krnl.vector_type_cast %2 : memref<16x16xf32> to memref<16x2xvector<8xf32>>
                %8 = memref.alloca() {alignment = 64 : i64} : memref<vector<8xf32>>
                %9 = memref.load %7[%arg7, %5] : memref<16x2xvector<8xf32>>
                affine.store %9, %8[] : memref<vector<8xf32>>
                %10 = addi %arg7, %c1 : index
                %11 = memref.load %7[%10, %5] : memref<16x2xvector<8xf32>>
                affine.store %11, %8[] : memref<vector<8xf32>>
                %12 = addi %arg7, %c2 : index
                %13 = memref.load %7[%12, %5] : memref<16x2xvector<8xf32>>
                affine.store %13, %8[] : memref<vector<8xf32>>
                %14 = addi %arg7, %c3 : index
                %15 = memref.load %7[%14, %5] : memref<16x2xvector<8xf32>>
                affine.store %15, %8[] : memref<vector<8xf32>>
                affine.for %arg8 = 0 to 512 {
                  %20 = memref.load %1[%3, %arg8] : memref<64x512xf32>
                  %21 = vector.broadcast %20 : f32 to vector<8xf32>
                  %22 = memref.load %6[%arg8, %4] : memref<512x16xvector<8xf32>>
                  %23 = affine.load %8[] : memref<vector<8xf32>>
                  %24 = vector.fma %21, %22, %23 : vector<8xf32>
                  affine.store %24, %8[] : memref<vector<8xf32>>
                  %25 = addi %3, %c1 : index
                  %26 = memref.load %1[%25, %arg8] : memref<64x512xf32>
                  %27 = vector.broadcast %26 : f32 to vector<8xf32>
                  %28 = memref.load %6[%arg8, %4] : memref<512x16xvector<8xf32>>
                  %29 = affine.load %8[] : memref<vector<8xf32>>
                  %30 = vector.fma %27, %28, %29 : vector<8xf32>
                  affine.store %30, %8[] : memref<vector<8xf32>>
                  %31 = addi %3, %c2 : index
                  %32 = memref.load %1[%31, %arg8] : memref<64x512xf32>
                  %33 = vector.broadcast %32 : f32 to vector<8xf32>
                  %34 = memref.load %6[%arg8, %4] : memref<512x16xvector<8xf32>>
                  %35 = affine.load %8[] : memref<vector<8xf32>>
                  %36 = vector.fma %33, %34, %35 : vector<8xf32>
                  affine.store %36, %8[] : memref<vector<8xf32>>
                  %37 = addi %3, %c3 : index
                  %38 = memref.load %1[%37, %arg8] : memref<64x512xf32>
                  %39 = vector.broadcast %38 : f32 to vector<8xf32>
                  %40 = memref.load %6[%arg8, %4] : memref<512x16xvector<8xf32>>
                  %41 = affine.load %8[] : memref<vector<8xf32>>
                  %42 = vector.fma %39, %40, %41 : vector<8xf32>
                  affine.store %42, %8[] : memref<vector<8xf32>>
                }
                %16 = affine.load %8[] : memref<vector<8xf32>>
                memref.store %16, %7[%arg7, %5] : memref<16x2xvector<8xf32>>
                %17 = affine.load %8[] : memref<vector<8xf32>>
                memref.store %17, %7[%10, %5] : memref<16x2xvector<8xf32>>
                %18 = affine.load %8[] : memref<vector<8xf32>>
                memref.store %18, %7[%12, %5] : memref<16x2xvector<8xf32>>
                %19 = affine.load %8[] : memref<vector<8xf32>>
                memref.store %19, %7[%14, %5] : memref<16x2xvector<8xf32>>
              } else {
                affine.if #set1() {
                  %6 = krnl.vector_type_cast %0 : memref<512x128xf32> to memref<512x16xvector<8xf32>>
                  %7 = krnl.vector_type_cast %2 : memref<16x16xf32> to memref<16x2xvector<8xf32>>
                  %8 = memref.alloca() {alignment = 64 : i64} : memref<vector<8xf32>>
                  affine.for %arg8 = 0 to 4 {
                    %9 = addi %arg8, %arg7 : index
                    %10 = memref.load %7[%9, %5] : memref<16x2xvector<8xf32>>
                    affine.store %10, %8[] : memref<vector<8xf32>>
                    affine.for %arg9 = 0 to min #map0(%arg4) {
                      %12 = addi %arg8, %3 : index
                      %13 = memref.load %1[%12, %arg9] : memref<64x512xf32>
                      %14 = vector.broadcast %13 : f32 to vector<8xf32>
                      %15 = memref.load %6[%arg9, %4] : memref<512x16xvector<8xf32>>
                      %16 = affine.load %8[] : memref<vector<8xf32>>
                      %17 = vector.fma %14, %15, %16 : vector<8xf32>
                      affine.store %17, %8[] : memref<vector<8xf32>>
                    }
                    %11 = affine.load %8[] : memref<vector<8xf32>>
                    memref.store %11, %7[%9, %5] : memref<16x2xvector<8xf32>>
                  }
                }
              }
            }
          }
        }
      }
    }
    affine.for %arg3 = 0 to 16 {
      affine.for %arg4 = 0 to 16 {
        %3 = affine.load %2[%arg3, %arg4] : memref<16x16xf32>
        %4 = affine.load %arg2[%arg4] : memref<16xf32>
        %5 = addf %3, %4 : f32
        affine.store %5, %2[%arg3, %arg4] : memref<16x16xf32>
      }
    }
    memref.dealloc %1 : memref<64x512xf32>
    memref.dealloc %0 : memref<512x128xf32>
    return %2 : memref<16x16xf32>
  }
}

hi alex, starting Step 1 to run convert krnl to affine
hi alex, starting Step 2 to run convert krnl to affine
hi alex, starting Step 3 to run convert krnl to affine
hi alex, starting Step 4 to run convert krnl to affine
hi alex, starting Step 5 to run convert krnl to affine
hi alex, starting Step 6 to run convert krnl to affine
hi alex, done to run convert krnl to affine

I don't know at this time how to debug it.

tungld commented 3 years ago

It seems that you are using unrollAndJam from MLIR, which does not know about krnl.load and krnl.store. So at the point in the lowering where krnl.load and krnl.store have not yet been rewritten into affine.load and affine.store, they are illegal inputs for unrollAndJam. This likely depends on the order of evaluation.
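To make the ordering constraint concrete, here is a minimal sketch (header paths and the exact namespace of loopUnrollJamByFactor differ across MLIR versions, and the helper unrollAndJamIfLegal is made up): defer unroll-and-jam until no krnl operations remain inside the loop nest.

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Transforms/LoopUtils.h"

// Hypothetical guard: only apply unroll-and-jam once every krnl.load /
// krnl.store in the loop body has been lowered to its affine equivalent.
static mlir::LogicalResult unrollAndJamIfLegal(mlir::AffineForOp forOp,
                                               uint64_t factor) {
  bool hasKrnlOps = false;
  forOp.walk([&](mlir::Operation *op) {
    if (op->getName().getDialectNamespace() == "krnl")
      hasKrnlOps = true; // loop body still contains krnl ops
  });
  if (hasKrnlOps)
    return mlir::failure(); // defer until krnl.load/store are lowered
  return mlir::loopUnrollJamByFactor(forOp, factor);
}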

AlexandreEichenberger commented 3 years ago

@tungld, we found a way to make it work... will push the solution soon.