tensorflow / mlir

"Multi-Level Intermediate Representation" Compiler Infrastructure

Expanding Op into multiple Ops during Lowering broke IR? #188

Closed: Naville closed this issue 5 years ago

Naville commented 5 years ago

TL;DR: I need to replace the matched op with two chained ops. The second (final) op's result type must match the result type of the old op.


Hi, in our dialect we need to lower an op into multiple target-specific ops. Here is the relevant code:

class OpConverter : public ConversionPattern {
public:
  OpConverter(MLIRContext *ctx)
      : ConversionPattern("Op", 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    SourceOp sop = cast<SourceOp>(op);
    // Create the target-specific op that produces the new (i32) value.
    TargetOp top = rewriter.create<TargetOp>(op->getLoc(), ........);
    // Replace the source op with a cast back to the original result type.
    rewriter.replaceOpWithNewOp<SomeCastOp>(sop, .....);
    return matchSuccess();
  }
};

The input is the following:

...
func @XXXX(%arg0: vector<3 x i32>, %arg1: i1) {
  %arg = "SOURCEOP"() {A = "B"} : () -> index
  %ret = index_cast %arg : index to i32
  return
}
....

I expect the following MLIR to be generated:

func @XXXX(%arg0: vector<3 x i32>, %arg1: i1) {
  %arg = "TARGETOP"() {A = "B"} : () -> i32
  %0 = index_cast %arg : i32 to index
  %ret = index_cast %0 : index to i32
  return
}

Instead I got:

func @XXXX(%arg0: vector<3 x i32>, %arg1: i1) {
  %arg = "TARGETOP"() {A = "B"} : () -> i32
  %0 = index_cast %arg : i32 to index
  %ret = index_cast %arg0 : vector<3 x i32> to i32
  return
}

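To make the expected rewiring concrete, here is a minimal sketch in plain C++ using a hypothetical toy IR. `ToyOp`, `ToyBlock`, `exampleInput`, and `expandOp` are illustrative names, not MLIR APIs, and the sketch assumes single-result ops identified by integer value ids. It models the invariant the expected output relies on: after the op is expanded into two chained ops, every former user of the old result consumes the final (cast) op's result, never an unrelated value such as `%arg0`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy IR, NOT MLIR's API: each op has a name, a result value id
// (-1 if it produces nothing), a result type, and a list of operand value ids.
struct ToyOp {
  std::string name;
  int result;
  std::string resultType;
  std::vector<int> operands;
};

using ToyBlock = std::vector<ToyOp>;

// The input from the issue. Value ids 0 and 1 stand for %arg0 and %arg1;
// id 2 is %arg (SOURCEOP's result) and id 3 is %ret.
ToyBlock exampleInput() {
  return {
      {"SOURCEOP", 2, "index", {}},
      {"index_cast", 3, "i32", {2}},
      {"return", -1, "", {}},
  };
}

// Replace the op at `index` with two chained ops:
//   %t = TARGETOP : i32
//   %c = index_cast %t : i32 to <old result type>
// then redirect every use of the old result to %c (the final op's result).
ToyBlock expandOp(ToyBlock block, std::size_t index, int nextValueId) {
  assert(index < block.size());
  const ToyOp old = block[index];
  const int target = nextValueId++;
  const int cast = nextValueId++;

  // The new op takes the old op's slot; the cast is inserted right after it.
  block[index] = ToyOp{"TARGETOP", target, "i32", {}};
  block.insert(block.begin() + static_cast<std::ptrdiff_t>(index) + 1,
               ToyOp{"index_cast", cast, old.resultType, {target}});

  // Rewire users of the old result so they consume the cast's result.
  for (ToyOp &op : block)
    for (int &operand : op.operands)
      if (operand == old.result)
        operand = cast;
  return block;
}
```

Calling `expandOp(exampleInput(), 0, 4)` yields `TARGETOP` (id 4, `i32`), an `index_cast` of id 4 to `index` (id 5), and the original `index_cast` now reading id 5, which mirrors the expected IR. In MLIR this last rewiring step is the job of `replaceOp`-style rewrites, which replace all uses of the old op's results with the supplied values.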
Naville commented 5 years ago

I've also tried using the Declarative Rewrite Rules (DRR) framework for matching and lowering, with a pattern like the following:

def XXXX : Pat<(SOURCEOP:$op .....),
               (IndexCastOp (TARGETOP))>;

where the target op is defined as follows:

def XXX_YYYYOp : XXXXOp<"XXXX", [NoSideEffect]> {
  let arguments = (ins);
  let results = (outs
    AnyTypeOf<[I16, I32]>:$value
  );
}

However, as far as I can tell, DRR doesn't allow passing the result of the matched op around, so I can't call YYYYOp::build: its last argument, the result Type, is missing. Adding a new builder with the type hard-coded works, but only as a workaround:

let builders = [OpBuilder<
  "Builder *builder, OperationState &result", [{
    result.addTypes(builder->getIntegerType(32));
  }]>];