xdslproject / xdsl

A Python Compiler Design Toolkit

Type Conversion of function arguments #2857

Closed: francescodaghero closed this issue 1 week ago

francescodaghero commented 1 week ago

I'm not sure whether this is intended behavior or whether I'm doing something wrong. When I apply a TypeConversionPattern to a FuncOp, the types of its arguments do not change, even with the recursive flag set to True.

An example:

```python
from xdsl.dialects.func import Func
from xdsl.dialects.builtin import (
    Builtin,
    IndexType,
    IntegerType,
)
from xdsl.parser import Parser
from xdsl.pattern_rewriter import (
    PatternRewriteWalker,
    TypeConversionPattern,
    attr_type_rewrite_pattern,
)
from xdsl.context import MLContext

# A private func.func whose two arguments are memrefs of unsigned integers.
prog = """\
"builtin.module"() ({
  "func.func"() <{function_type = (memref<2x4xui16>, memref<2x3xui8>) -> (), sym_name = "main", sym_visibility = "private"}> ({
  ^bb0(%arg0: memref<2x4xui16>, %arg1: memref<2x3xui8>):
    "func.return"() : () -> ()
  }) : () -> ()
}) : () -> ()
"""

ctx = MLContext()
ctx.load_dialect(Builtin)
ctx.load_dialect(Func)
parser = Parser(ctx, prog)
module = parser.parse_module()

# Replace every IntegerType (ui16 and ui8 included) with IndexType.
# recursive=True should also reach types nested inside parameterized
# types such as the memref element type.
class Rewrite(TypeConversionPattern):
    @attr_type_rewrite_pattern
    def convert_type(self, typ: IntegerType) -> IndexType:
        return IndexType()

# Apply the pattern to the whole module, then print the result.
PatternRewriteWalker(Rewrite(recursive=True), apply_recursively=True).rewrite_module(
    module
)
print(module)
```

The output of the above code snippet is:

```
builtin.module {
  func.func private @main(%arg0 : memref<2x4xui16>, %arg1 : memref<2x3xui8>) {
    func.return
  }
}
```

while I expected:

```
builtin.module {
  func.func private @main(%arg0 : memref<2x4xindex>, %arg1 : memref<2x3xindex>) {
    func.return
  }
}
```
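To make the expected behavior concrete, here is a small self-contained toy model. It is only an illustration under stated assumptions: `IntType`, `IdxType`, `MemRef`, `ToyFunc`, `convert` and `convert_func` are hypothetical stand-ins, not xDSL classes, and this does not claim to reproduce how xDSL's TypeConversionPattern is implemented. It shows two things the expected output relies on: a recursive conversion has to descend into the memref element type, and it has to reach the entry block's argument types (`%arg0`, `%arg1`, which are what the printed signature shows) as well as the `function_type` inputs.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional, Union


# Hypothetical toy types (NOT xDSL's builtin classes).
@dataclass(frozen=True)
class IntType:
    width: int
    prefix: str = ""  # "" = signless, "u" = unsigned, "s" = signed


@dataclass(frozen=True)
class IdxType:
    pass


@dataclass(frozen=True)
class MemRef:
    shape: tuple[int, ...]
    element_type: ToyType


ToyType = Union[IntType, IdxType, MemRef]


def show(t: ToyType) -> str:
    """Render a toy type in MLIR-like syntax, e.g. memref<2x4xui16>."""
    if isinstance(t, IntType):
        return f"{t.prefix}i{t.width}"
    if isinstance(t, IdxType):
        return "index"
    dims = "x".join(str(d) for d in t.shape)
    return f"memref<{dims}x{show(t.element_type)}>"


def convert(t: ToyType, leaf: Callable[[ToyType], Optional[ToyType]], recursive: bool) -> ToyType:
    """Apply `leaf` to `t`; if `recursive`, convert the nested element type first."""
    if recursive and isinstance(t, MemRef):
        t = MemRef(t.shape, convert(t.element_type, leaf, recursive))
    new = leaf(t)
    return t if new is None else new


def int_to_index(t: ToyType) -> Optional[ToyType]:
    """The rule from the issue: any integer type becomes index."""
    return IdxType() if isinstance(t, IntType) else None


@dataclass
class ToyFunc:
    inputs: list[ToyType]           # mirrors the input types in function_type
    block_arg_types: list[ToyType]  # types of the entry block's arguments


def convert_func(f: ToyFunc, leaf, recursive: bool) -> None:
    # Rewrite both the signature and the block arguments so they stay in sync.
    f.inputs = [convert(t, leaf, recursive) for t in f.inputs]
    f.block_arg_types = [convert(t, leaf, recursive) for t in f.block_arg_types]


args = [MemRef((2, 4), IntType(16, "u")), MemRef((2, 3), IntType(8, "u"))]

flat = ToyFunc(list(args), list(args))
convert_func(flat, int_to_index, recursive=False)
print([show(t) for t in flat.block_arg_types])  # ['memref<2x4xui16>', 'memref<2x3xui8>']

deep = ToyFunc(list(args), list(args))
convert_func(deep, int_to_index, recursive=True)
print([show(t) for t in deep.block_arg_types])  # ['memref<2x4xindex>', 'memref<2x3xindex>']
```

In the toy, the non-recursive run leaves the memrefs untouched because the rule only matches integer types at the top level; the recursive run reproduces the signature the issue expects.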
superlopuh commented 1 week ago

CC @PapyChacal