
[mlir] [Vector] implement vector.extract for 2-d vector type (or flatten to a 1-d array type) #87667

Open LLITCHEV opened 7 months ago

LLITCHEV commented 7 months ago

The ConvertVectorToLLVM patterns do not convert %85 = vector.extract %75[0, %83] : i1 from vector<1x16xi1> to the LLVM dialect. It appears the conversion handles a dynamic extraction position only for 1-d vector types.
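
A minimal form of the failing pattern, distilled from the full module below (the function name @repro and its signature are mine, for illustration only):

func.func @repro(%v: vector<1x16xi1>, %i: index) -> i1 {
  // A dynamic innermost position on a 2-d vector; this is the op the
  // vector-to-LLVM lowering currently leaves unconverted.
  %0 = vector.extract %v[0, %i] : i1 from vector<1x16xi1>
  return %0 : i1
}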

module {
  func.func @custom_call_topk_tuple_16_dispatch_0_topk_1x32xf32() {
    %cst = arith.constant dense<false> : vector<1xi1>
    %cst_0 = arith.constant dense<true> : vector<16xi1>
    %c0_i32 = arith.constant 0 : i32
    %false = arith.constant false
    %c4 = arith.constant 4 : index
    %cst_1 = arith.constant 0.000000e+00 : f32
    %true = arith.constant true
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x32xf32>
    memref.assume_alignment %0, 64 : memref<1x32xf32>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1x4xf32>
    memref.assume_alignment %1, 64 : memref<1x4xf32>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<1x4xi32, strided<[4, 1], offset: 16>>
    memref.assume_alignment %2, 64 : memref<1x4xi32, strided<[4, 1], offset: 16>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    cf.br ^bb1(%workgroup_id_x : index)
  ^bb1(%3: index):  // 2 preds: ^bb0, ^bb24
    %4 = arith.cmpi slt, %3, %c1 : index
    cf.cond_br %4, ^bb2(%c0, %true, %cst_1, %c0 : index, i1, f32, index), ^bb25
  ^bb2(%5: index, %6: i1, %7: f32, %8: index):  // 2 preds: ^bb1, ^bb23
    %9 = arith.cmpi slt, %5, %c32 : index
    cf.cond_br %9, ^bb3, ^bb24
  ^bb3:  // pred: ^bb2
    %10 = arith.select %6, %c1, %8 : index
    cf.cond_br %6, ^bb4, ^bb5(%7 : f32)
  ^bb4:  // pred: ^bb3
    %11 = memref.load %0[%c0, %5] : memref<1x32xf32>
    memref.store %11, %1[%c0, %c0] : memref<1x4xf32>
    memref.store %c0_i32, %2[%c0, %c0] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    cf.br ^bb5(%11 : f32)
  ^bb5(%12: f32):  // 2 preds: ^bb3, ^bb4
    %13 = vector.load %0[%c0, %5] : memref<1x32xf32>, vector<16xf32>
    %14 = vector.broadcast %13 : vector<16xf32> to vector<1x16xf32>
    %15 = vector.broadcast %12 : f32 to vector<16xf32>
    %16 = arith.cmpf ogt, %13, %15 : vector<16xf32>
    %17 = arith.cmpi slt, %10, %c4 : index
    %18 = arith.select %17, %cst_0, %16 : vector<16xi1>
    %19 = vector.broadcast %18 : vector<16xi1> to vector<1x16xi1>
    %20 = vector.reduction <or>, %18, %false : vector<16xi1> into i1
    %21 = vector.insertelement %20, %cst[%c0 : index] : vector<1xi1>
    %22 = vector.extract %21[0] : i1 from vector<1xi1>
    cf.cond_br %22, ^bb6(%c0, %12, %10 : index, f32, index), ^bb23(%12, %10 : f32, index)
  ^bb6(%23: index, %24: f32, %25: index):  // 2 preds: ^bb5, ^bb22
    %26 = arith.cmpi slt, %23, %c16 : index
    cf.cond_br %26, ^bb7, ^bb23(%24, %25 : f32, index)
  ^bb7:  // pred: ^bb6
    %27 = vector.extract %19[0, %23] : i1 from vector<1x16xi1>
    %28 = arith.cmpi eq, %27, %true : i1
    cf.cond_br %28, ^bb8, ^bb22(%24, %25 : f32, index)
  ^bb8:  // pred: ^bb7
    %29 = vector.extract %14[0, %23] : f32 from vector<1x16xf32>
    %30 = arith.addi %23, %5 : index
    cf.br ^bb9(%c0, %true, %c0 : index, i1, index)
  ^bb9(%31: index, %32: i1, %33: index):  // 2 preds: ^bb8, ^bb12
    %34 = arith.cmpi slt, %31, %25 : index
    cf.cond_br %34, ^bb10, ^bb13
  ^bb10:  // pred: ^bb9
    %35 = arith.cmpi eq, %32, %true : i1
    cf.cond_br %35, ^bb11, ^bb12(%32, %33 : i1, index)
  ^bb11:  // pred: ^bb10
    %36 = memref.load %1[%c1, %31] : memref<1x4xf32>
    %37 = arith.cmpf olt, %36, %29 : f32
    %38 = arith.cmpi eq, %37, %true : i1
    %39 = arith.cmpi ne, %37, %true : i1
    %40 = arith.andi %39, %32 : i1
    %41 = arith.select %38, %31, %33 : index
    cf.br ^bb12(%40, %41 : i1, index)
  ^bb12(%42: i1, %43: index):  // 2 preds: ^bb10, ^bb11
    %44 = arith.addi %31, %c1 : index
    cf.br ^bb9(%44, %42, %43 : index, i1, index)
  ^bb13:  // pred: ^bb9
    %45 = arith.cmpi eq, %25, %c4 : index
    %46 = arith.andi %45, %32 : i1
    cf.cond_br %46, ^bb22(%24, %25 : f32, index), ^bb14
  ^bb14:  // pred: ^bb13
    cf.cond_br %32, ^bb15, ^bb16
  ^bb15:  // pred: ^bb14
    memref.store %29, %1[%c0, %25] : memref<1x4xf32>
    %47 = arith.index_cast %30 : index to i32
    memref.store %47, %2[%c0, %25] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    %48 = arith.addi %25, %c1 : index
    cf.br ^bb22(%29, %48 : f32, index)
  ^bb16:  // pred: ^bb14
    cf.cond_br %45, ^bb17, ^bb18(%25 : index)
  ^bb17:  // pred: ^bb16
    %49 = arith.subi %25, %c1 : index
    cf.br ^bb18(%49 : index)
  ^bb18(%50: index):  // 2 preds: ^bb16, ^bb17
    %51 = arith.subi %50, %c1 : index
    %52 = memref.load %1[%c0, %33] : memref<1x4xf32>
    %53 = memref.load %2[%c0, %33] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    cf.br ^bb19(%33, %52, %53 : index, f32, i32)
  ^bb19(%54: index, %55: f32, %56: i32):  // 2 preds: ^bb18, ^bb20
    %57 = arith.cmpi slt, %54, %51 : index
    cf.cond_br %57, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %58 = arith.addi %54, %c1 : index
    %59 = memref.load %1[%c0, %58] : memref<1x4xf32>
    %60 = memref.load %2[%c0, %58] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    memref.store %55, %1[%c0, %58] : memref<1x4xf32>
    memref.store %56, %2[%c0, %58] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    cf.br ^bb19(%58, %59, %60 : index, f32, i32)
  ^bb21:  // pred: ^bb19
    %61 = memref.load %1[%c0, %50] : memref<1x4xf32>
    %62 = arith.addi %50, %c1 : index
    cf.br ^bb22(%61, %62 : f32, index)
  ^bb22(%63: f32, %64: index):  // 4 preds: ^bb7, ^bb13, ^bb15, ^bb21
    %65 = arith.addi %23, %c1 : index
    cf.br ^bb6(%65, %63, %64 : index, f32, index)
  ^bb23(%66: f32, %67: index):  // 2 preds: ^bb5, ^bb6
    %68 = arith.addi %5, %c16 : index
    cf.br ^bb2(%68, %false, %66, %67 : index, i1, f32, index)
  ^bb24:  // pred: ^bb2
    %69 = arith.addi %3, %workgroup_count_x : index
    cf.br ^bb1(%69 : index)
  ^bb25:  // pred: ^bb1
    return
  }
}
LLITCHEV commented 7 months ago

Steps to reproduce:

  1. Put the code above into a file, e.g. ~/ToDelete/vect_to_llvm_convert.mlir.
  2. Run: iree-build/tools/iree-opt --iree-convert-to-llvm ~/ToDelete/vect_to_llvm_convert.mlir
  3. The output still contains %87 = vector.extract %67[0, %83] : f32 from vector<1x16xf32>.
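
As the issue title suggests, one possible workaround (a sketch only, assuming flattening is acceptable at this point in the pipeline; not a committed fix) is to flatten the 2-d vector with vector.shape_cast and extract from the 1-d result, since a 1-d extract with a dynamic position already lowers to llvm.extractelement:

func.func @flattened(%v: vector<1x16xf32>, %i: index) -> f32 {
  // vector<1x16xf32> and vector<16xf32> hold the same 16 elements, so this
  // shape_cast is a pure reshape.
  %flat = vector.shape_cast %v : vector<1x16xf32> to vector<16xf32>
  // A 1-d extract with a dynamic index is supported by the LLVM conversion.
  %0 = vector.extract %flat[%i] : f32 from vector<16xf32>
  return %0 : f32
}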