The VectorConvertToLLVM pass is not converting %85 = vector.extract %75[0, %83] : i1 from vector<1x16xi1> to the LLVM dialect. It appears the conversion code only handles 1-D vector types when the extraction position is dynamic, so this n-D extract is left unconverted.
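My guess at the cause (an assumption on my part, not verified against the pass implementation): with the default LLVM type conversion, vector<1x16xi1> becomes !llvm.array<1 x vector<16xi1>>, and only the innermost 1-D vector supports a dynamic position via llvm.extractelement; the outer array dimension can only be addressed with llvm.extractvalue, which takes static positions. A sketch of the two cases in the LLVM dialect:

```mlir
// Sketch only (an assumption of how a lowering could look, not the pass's
// actual output): vector<1x16xi1> converts to !llvm.array<1 x vector<16xi1>>.
func.func @sketch(%v2d: !llvm.array<1 x vector<16xi1>>, %idx: i64) -> i1 {
  // Outer (array) dimension: llvm.extractvalue requires a static position.
  %row = llvm.extractvalue %v2d[0] : !llvm.array<1 x vector<16xi1>>
  // Inner (1-D vector) dimension: a dynamic position is allowed here.
  %bit = llvm.extractelement %row[%idx : i64] : vector<16xi1>
  return %bit : i1
}
```

Since the failing op uses a static 0 for the outer position, a two-step lowering like this looks possible, but the current pattern does not seem to attempt it. The full module that fails to convert: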
```mlir
module {
  func.func @custom_call_topk_tuple_16_dispatch_0_topk_1x32xf32() {
    %cst = arith.constant dense<false> : vector<1xi1>
    %cst_0 = arith.constant dense<true> : vector<16xi1>
    %c0_i32 = arith.constant 0 : i32
    %false = arith.constant false
    %c4 = arith.constant 4 : index
    %cst_1 = arith.constant 0.000000e+00 : f32
    %true = arith.constant true
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x32xf32>
    memref.assume_alignment %0, 64 : memref<1x32xf32>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1x4xf32>
    memref.assume_alignment %1, 64 : memref<1x4xf32>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<1x4xi32, strided<[4, 1], offset: 16>>
    memref.assume_alignment %2, 64 : memref<1x4xi32, strided<[4, 1], offset: 16>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    cf.br ^bb1(%workgroup_id_x : index)
  ^bb1(%3: index):  // 2 preds: ^bb0, ^bb24
    %4 = arith.cmpi slt, %3, %c1 : index
    cf.cond_br %4, ^bb2(%c0, %true, %cst_1, %c0 : index, i1, f32, index), ^bb25
  ^bb2(%5: index, %6: i1, %7: f32, %8: index):  // 2 preds: ^bb1, ^bb23
    %9 = arith.cmpi slt, %5, %c32 : index
    cf.cond_br %9, ^bb3, ^bb24
  ^bb3:  // pred: ^bb2
    %10 = arith.select %6, %c1, %8 : index
    cf.cond_br %6, ^bb4, ^bb5(%7 : f32)
  ^bb4:  // pred: ^bb3
    %11 = memref.load %0[%c0, %5] : memref<1x32xf32>
    memref.store %11, %1[%c0, %c0] : memref<1x4xf32>
    memref.store %c0_i32, %2[%c0, %c0] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    cf.br ^bb5(%11 : f32)
  ^bb5(%12: f32):  // 2 preds: ^bb3, ^bb4
    %13 = vector.load %0[%c0, %5] : memref<1x32xf32>, vector<16xf32>
    %14 = vector.broadcast %13 : vector<16xf32> to vector<1x16xf32>
    %15 = vector.broadcast %12 : f32 to vector<16xf32>
    %16 = arith.cmpf ogt, %13, %15 : vector<16xf32>
    %17 = arith.cmpi slt, %10, %c4 : index
    %18 = arith.select %17, %cst_0, %16 : vector<16xi1>
    %19 = vector.broadcast %18 : vector<16xi1> to vector<1x16xi1>
    %20 = vector.reduction <or>, %18, %false : vector<16xi1> into i1
    %21 = vector.insertelement %20, %cst[%c0 : index] : vector<1xi1>
    %22 = vector.extract %21[0] : i1 from vector<1xi1>
    cf.cond_br %22, ^bb6(%c0, %12, %10 : index, f32, index), ^bb23(%12, %10 : f32, index)
  ^bb6(%23: index, %24: f32, %25: index):  // 2 preds: ^bb5, ^bb22
    %26 = arith.cmpi slt, %23, %c16 : index
    cf.cond_br %26, ^bb7, ^bb23(%24, %25 : f32, index)
  ^bb7:  // pred: ^bb6
    %27 = vector.extract %19[0, %23] : i1 from vector<1x16xi1>
    %28 = arith.cmpi eq, %27, %true : i1
    cf.cond_br %28, ^bb8, ^bb22(%24, %25 : f32, index)
  ^bb8:  // pred: ^bb7
    %29 = vector.extract %14[0, %23] : f32 from vector<1x16xf32>
    %30 = arith.addi %23, %5 : index
    cf.br ^bb9(%c0, %true, %c0 : index, i1, index)
  ^bb9(%31: index, %32: i1, %33: index):  // 2 preds: ^bb8, ^bb12
    %34 = arith.cmpi slt, %31, %25 : index
    cf.cond_br %34, ^bb10, ^bb13
  ^bb10:  // pred: ^bb9
    %35 = arith.cmpi eq, %32, %true : i1
    cf.cond_br %35, ^bb11, ^bb12(%32, %33 : i1, index)
  ^bb11:  // pred: ^bb10
    %36 = memref.load %1[%c1, %31] : memref<1x4xf32>
    %37 = arith.cmpf olt, %36, %29 : f32
    %38 = arith.cmpi eq, %37, %true : i1
    %39 = arith.cmpi ne, %37, %true : i1
    %40 = arith.andi %39, %32 : i1
    %41 = arith.select %38, %31, %33 : index
    cf.br ^bb12(%40, %41 : i1, index)
  ^bb12(%42: i1, %43: index):  // 2 preds: ^bb10, ^bb11
    %44 = arith.addi %31, %c1 : index
    cf.br ^bb9(%44, %42, %43 : index, i1, index)
  ^bb13:  // pred: ^bb9
    %45 = arith.cmpi eq, %25, %c4 : index
    %46 = arith.andi %45, %32 : i1
    cf.cond_br %46, ^bb22(%24, %25 : f32, index), ^bb14
  ^bb14:  // pred: ^bb13
    cf.cond_br %32, ^bb15, ^bb16
  ^bb15:  // pred: ^bb14
    memref.store %29, %1[%c0, %25] : memref<1x4xf32>
    %47 = arith.index_cast %30 : index to i32
    memref.store %47, %2[%c0, %25] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    %48 = arith.addi %25, %c1 : index
    cf.br ^bb22(%29, %48 : f32, index)
  ^bb16:  // pred: ^bb14
    cf.cond_br %45, ^bb17, ^bb18(%25 : index)
  ^bb17:  // pred: ^bb16
    %49 = arith.subi %25, %c1 : index
    cf.br ^bb18(%49 : index)
  ^bb18(%50: index):  // 2 preds: ^bb16, ^bb17
    %51 = arith.subi %50, %c1 : index
    %52 = memref.load %1[%c0, %33] : memref<1x4xf32>
    %53 = memref.load %2[%c0, %33] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    cf.br ^bb19(%33, %52, %53 : index, f32, i32)
  ^bb19(%54: index, %55: f32, %56: i32):  // 2 preds: ^bb18, ^bb20
    %57 = arith.cmpi slt, %54, %51 : index
    cf.cond_br %57, ^bb20, ^bb21
  ^bb20:  // pred: ^bb19
    %58 = arith.addi %54, %c1 : index
    %59 = memref.load %1[%c0, %58] : memref<1x4xf32>
    %60 = memref.load %2[%c0, %58] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    memref.store %55, %1[%c0, %58] : memref<1x4xf32>
    memref.store %56, %2[%c0, %58] : memref<1x4xi32, strided<[4, 1], offset: 16>>
    cf.br ^bb19(%58, %59, %60 : index, f32, i32)
  ^bb21:  // pred: ^bb19
    %61 = memref.load %1[%c0, %50] : memref<1x4xf32>
    %62 = arith.addi %50, %c1 : index
    cf.br ^bb22(%61, %62 : f32, index)
  ^bb22(%63: f32, %64: index):  // 4 preds: ^bb7, ^bb13, ^bb15, ^bb21
    %65 = arith.addi %23, %c1 : index
    cf.br ^bb6(%65, %63, %64 : index, f32, index)
  ^bb23(%66: f32, %67: index):  // 2 preds: ^bb5, ^bb6
    %68 = arith.addi %5, %c16 : index
    cf.br ^bb2(%68, %false, %66, %67 : index, i1, f32, index)
  ^bb24:  // pred: ^bb2
    %69 = arith.addi %3, %workgroup_count_x : index
    cf.br ^bb1(%69 : index)
  ^bb25:  // pred: ^bb1
    return
  }
}
```
Steps to reproduce:
Run the VectorConvertToLLVM pass on the module above. The vector.extract with a dynamic position from vector<1x16xi1> (%27 in the IR above) is not converted to the LLVM dialect; the conversion appears to handle only 1-D vector types.
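If it helps, here is a reduced test case of my own (a hypothetical reduction, not taken from the original program) that isolates the same pattern: a static outer position and a dynamic inner position on a 2-D i1 vector.

```mlir
// Hypothetical reduced reproducer: vector.extract with a dynamic inner
// position from a 2-D vector, the case that appears to be unsupported.
func.func @extract_2d_dynamic(%v: vector<1x16xi1>, %idx: index) -> i1 {
  %0 = vector.extract %v[0, %idx] : i1 from vector<1x16xi1>
  return %0 : i1
}
```

Running just this function through the same conversion should show whether the n-D dynamic-position case is simply missing a pattern.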