cornell-zhang / allo

Allo: A Programming Model for Composable Accelerator Design
https://cornell-zhang.github.io/allo
Apache License 2.0

Large Bit Integers #168

Open Arsalan-Zahid opened 1 month ago

Arsalan-Zahid commented 1 month ago

Hello, I'm working with Debjit's research team. We are interested in adding large bit integers to Allo, but we have run into some difficulties. I saw that the AlloType class in types.py prohibited integers wider than 2047 bits, which made me wonder: is this a safeguard for the generated MLIR code? Is there some limitation within MLIR on integers wider than 2047 bits?

To work around it, I first tried to define my own class [1] and get it past the typing rules, but I was caught by typing_rule.py:123, which uses type() to get the class, and I don't know whether it's possible to get around that.
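A minimal, self-contained sketch of why a lookup keyed on `type()` rejects a user-defined class, while an `isinstance`-based lookup would accept subclasses of a registered class. The names `BaseInt`, `MyInt`, and `RULES` are hypothetical stand-ins; this is not the actual code at typing_rule.py:123.

```python
class BaseInt:
    """Hypothetical stand-in for a built-in integer type class."""

    def __init__(self, bits: int):
        self.bits = bits


class MyInt(BaseInt):
    """Hypothetical user-defined subclass, like the custom class in [1]."""


# Hypothetical rule table keyed by exact class.
RULES = {BaseInt: "int rule"}


def lookup_by_type(t):
    # type(t) is the exact class, so any class not registered in RULES
    # (including a subclass of a registered one) finds no rule.
    return RULES.get(type(t))


def lookup_by_isinstance(t):
    # isinstance also matches subclasses of a registered class.
    for cls, rule in RULES.items():
        if isinstance(t, cls):
            return rule
    return None
```

With this table, `lookup_by_type(MyInt(8))` returns `None`, whereas `lookup_by_isinstance(MyInt(8))` returns `"int rule"`.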

Are we headed in the right direction? Is this compatible with MLIR? If not, what other approach could we take?

@paldebjit @chhzh123 @zzzDavid

[1]


from allo.ir.types import AlloType, Int as alloInt
# IntegerType comes from the MLIR Python bindings used by Allo (import omitted here)


class Int(AlloType):
    def __init__(self, bits):
        super().__init__(bits, 0, f"i{bits}")
        self.name = f"i{bits}"
        print(self.name)

    def __repr__(self):
        return self.name

    def __hash__(self):
        # __hash__ must return an int, not the name string itself
        return hash(self.name)

    def build(self):
        return IntegerType.get_signless(self.bits)

    @staticmethod
    def isinstance(other):
        # return isinstance(other, (Int, int))
        return isinstance(other, (alloInt, int))
chhzh123 commented 1 month ago

Hi @Arsalan-Zahid, thanks for your interest in our project. You do not need to define a new class to support larger bitwidths; you can simply remove this guard. For example, with the guard removed, running the following code generates the corresponding MLIR module.

import allo
from allo.ir.types import Int


def test_large_bitwidth():
    def kernel(a: Int(65536), b: Int(345)) -> Int(65536):
        return a + b

    s = allo.customize(kernel)
    print(s.module)

This prints:

module {
  func.func @kernel(%arg0: i65536, %arg1: i345) -> i65536 attributes {itypes = "ss", otypes = "s"} {
    %0 = arith.extsi %arg0 : i65536 to i65537
    %1 = arith.extsi %arg1 : i345 to i65537
    %2 = arith.addi %0, %1 : i65537
    %3 = arith.trunci %2 : i65537 to i65536
    return %3 : i65536
  }
}
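The printed IR sign-extends both operands to a common 65537-bit width (one bit wider than the widest operand, so the addition itself cannot overflow), adds them, and truncates back to the declared 65536-bit return type. As a sketch, assuming the usual two's-complement semantics of `arith.extsi`, `arith.addi`, and `arith.trunci`, the same computation can be emulated with Python's arbitrary-precision ints; the helper names below are illustrative only and are not part of Allo.

```python
def extsi(x: int, from_bits: int, to_bits: int) -> int:
    """Sign-extend; on Python's unbounded ints the signed value is unchanged."""
    assert to_bits > from_bits
    assert -(1 << (from_bits - 1)) <= x < (1 << (from_bits - 1)), "x must fit in from_bits"
    return x


def trunci(x: int, to_bits: int) -> int:
    """Keep the low to_bits bits and reinterpret them as two's complement."""
    u = x & ((1 << to_bits) - 1)
    return u - (1 << to_bits) if u >> (to_bits - 1) else u


def addi(x: int, y: int, bits: int) -> int:
    """Fixed-width addition that wraps modulo 2**bits."""
    return trunci(x + y, bits)


def kernel(a: int, b: int) -> int:
    """Mirrors the IR above: i65536 + i345, computed in i65537, returned as i65536."""
    a_ext = extsi(a, 65536, 65537)
    b_ext = extsi(b, 345, 65537)
    s = addi(a_ext, b_ext, 65537)
    return trunci(s, 65536)
```

For example, `kernel(1, 2)` gives `3`, and adding 1 to the largest i65536 value wraps around to the smallest one, `-(1 << 65535)`, because of the final `trunci`.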

However, you cannot pass data into it to verify correctness with the LLVM backend, because that backend depends heavily on C data types and NumPy, neither of which natively supports integers this wide.
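Python's own `int` is arbitrary-precision, so expected results can still be computed in plain Python. To pass such a value through a C-level interface, it would have to be split into native-width words. The sketch below shows one plausible layout, little-endian 64-bit unsigned limbs holding the two's-complement encoding. The helpers `to_limbs` and `from_limbs` are hypothetical; Allo does not provide them, and its backend is not assumed to accept this layout.

```python
def to_limbs(value: int, bits: int, limb_bits: int = 64) -> list[int]:
    """Split a signed bits-wide integer into little-endian unsigned limbs."""
    u = value & ((1 << bits) - 1)        # two's-complement encoding in `bits` bits
    n = -(-bits // limb_bits)            # ceiling division: number of limbs
    limb_mask = (1 << limb_bits) - 1
    return [(u >> (i * limb_bits)) & limb_mask for i in range(n)]


def from_limbs(limbs: list[int], bits: int, limb_bits: int = 64) -> int:
    """Reassemble limbs into a signed bits-wide integer."""
    u = 0
    for i, limb in enumerate(limbs):
        u |= limb << (i * limb_bits)
    u &= (1 << bits) - 1
    # Reinterpret the top bit as the sign bit.
    return u - (1 << bits) if u >> (bits - 1) else u
```

An i345 value needs 6 limbs (the last holding 25 bits), and an i65536 value needs 1024.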