llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

[MLIR][PDL] Invalid value from native rewrite #69888

Open qedawkins opened 1 year ago

qedawkins commented 1 year ago

Problem Description

I have a native rewrite function, registered with PDL, that rewrites the Values in a ValueRange element by element. It seems to fail to populate the PDL ByteCode result list with the new values. My rewrite function looks like this:

static ValueRange getI32TensorSizes(PatternRewriter &rewriter,
                                    ValueRange vals) {
  SmallVector<Value> flatI32TensorSizes;
  for (auto val : vals) {
    if (isa<IndexType>(val.getType())) {
      flatI32TensorSizes.push_back(rewriter.create<arith::IndexCastOp>(
          val.getLoc(), rewriter.getIntegerType(32), val));
    }
  }
  return ValueRange(flatI32TensorSizes);
}

patterns.getPDLPatterns().registerRewriteFunction("convert_index_to_i32",
                                                  getI32TensorSizes);

and my PDL IR looks like this:

%workload = pdl.apply_native_rewrite "get_tensor_sizes"(%range : !pdl.range<value>) : !pdl.range<value>
%new_dims = pdl.apply_native_rewrite "convert_index_to_i32"(%workload : !pdl.range<value>) : !pdl.range<value>

If I run my interpreter pass with --debug-only=pdl-bytecode, it successfully prints the arguments to the native rewrite but crashes while printing the results:

loc("/home/quinn/SHARK-Runtime/samples/custom_dispatch/vulkan/shaders/pattern_module.mlir":62:19)
Executing ApplyRewrite:
  * Arguments: %dim = tensor.dim %1, %c0_0 : tensor<?xf32>
  * Result: Please report issues to https://github.com/openxla/iree/issues and include the crash backtrace.

(The full stack backtrace is at https://gist.github.com/qedawkins/2f01e231caa8933c8c75c0b1a83b4d65, and the crash happens on this debug line: https://github.com/llvm/llvm-project/blob/a9136f0ad94bf7738c585c6d12ad5bbe1815f95b/mlir/lib/Rewrite/ByteCode.cpp#L1450)

If I print the surrounding IR immediately before returning from the native rewrite function, I can see that the index_cast I wanted to insert is there:

func.func @mixed_invocation(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %c0 = arith.constant 0 : index
  %0 = hal.buffer_view.dim<%arg0 : !hal.buffer_view>[0] : index
  %1 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<?xf32>{%0}
  %2 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %3 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<?xf32>{%2}
  %c0_0 = arith.constant 0 : index
  %dim = tensor.dim %1, %c0_0 : tensor<?xf32>
  %4 = arith.index_cast %dim : index to i32
  %5 = arith.mulf %1, %3 : tensor<?xf32>
  %6 = arith.addf %5, %3 : tensor<?xf32>
  %dim_1 = tensor.dim %6, %c0 : tensor<?xf32>
  %7 = hal.tensor.export %6 "output 0" : tensor<?xf32>{%dim_1} -> !hal.buffer_view
  return %7 : !hal.buffer_view
}
%4 = arith.index_cast %dim : index to i32

cc @MaheshRavishankar

llvmbot commented 1 year ago

@llvm/issue-subscribers-bug

Author: Quinn Dawkins (qedawkins)

qedawkins commented 1 year ago

Returning a SmallVector<Value> instead, and letting the result processing do the conversion for me, works fine:

static SmallVector<Value> getI32TensorSizes(PatternRewriter &rewriter,
                                            ValueRange vals) {
  SmallVector<Value> flatI32TensorSizes;
  for (auto val : vals) {
    if (isa<IndexType>(val.getType())) {
      flatI32TensorSizes.push_back(rewriter.create<arith::IndexCastOp>(
          val.getLoc(), rewriter.getIntegerType(32), val).getResult());
    }
  }
  return flatI32TensorSizes;
}
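The two versions differ in who owns the returned storage. mlir::ValueRange is a non-owning view, so `ValueRange(flatI32TensorSizes)` points into a SmallVector that is destroyed when getI32TensorSizes returns. That would plausibly explain the invalid value the bytecode debug printer trips over, while returning SmallVector<Value> hands ownership of the values back to the caller. Below is a minimal, MLIR-free sketch of that ownership difference, offered as an analogy rather than PDL's actual code. IntView, castToI32, and appendResults are hypothetical names; IntView only models the (pointer, length) shape of a view like ValueRange.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::ValueRange: a non-owning (pointer, length)
// view. Like ValueRange, it does not keep the storage it points into alive.
struct IntView {
  const std::int32_t *data;
  std::size_t size;
  explicit IntView(const std::vector<std::int32_t> &v)
      : data(v.data()), size(v.size()) {}
};

// UNSAFE shape, shown only as a comment: the returned view outlives `local`,
// so the caller would read freed memory (the analogue of
// `return ValueRange(flatI32TensorSizes);`).
//
//   IntView castToI32View(const std::vector<std::int64_t> &vals) {
//     std::vector<std::int32_t> local;
//     for (std::int64_t v : vals) local.push_back(static_cast<std::int32_t>(v));
//     return IntView(local);  // `local` is destroyed here; the view dangles
//   }

// Safe shape: return the owning container by value, as the working
// SmallVector<Value> version does. The caller now owns the storage.
std::vector<std::int32_t> castToI32(const std::vector<std::int64_t> &vals) {
  std::vector<std::int32_t> result;
  result.reserve(vals.size());
  for (std::int64_t v : vals)
    result.push_back(static_cast<std::int32_t>(v));
  return result;
}

// Hypothetical caller that copies the rewrite's results into a result list it
// owns, so nothing refers to the callee's (already destroyed) locals.
void appendResults(std::vector<std::int32_t> &resultList,
                   const std::vector<std::int64_t> &vals) {
  std::vector<std::int32_t> converted = castToI32(vals);
  // A view is fine here: `converted` outlives every use of `view`.
  IntView view(converted);
  for (std::size_t i = 0; i < view.size; ++i)
    resultList.push_back(view.data[i]);
  assert(resultList.size() >= view.size);
}
```

The rule the sketch illustrates: a view type is safe to pass down into a callee, but returning one from a function is only safe when it points at storage that outlives the call.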