jcasas00 closed this issue 1 year ago.
This issue is caused by MLIR's limitation on `IntegerAttr`. To create a constant operation, we first need to create an attribute that holds the value; this is required by the definition of the Arith dialect's `arith.constant` operation. An `IntegerAttr` supports only up to 64 bits.
However, I found that in the Python binding
attr_type = IntegerType.get_signless(op.dtype.bits)
value_attr = IntegerAttr.get(attr_type, value)
`value` can only be up to `0x7fff_ffff_ffff_ffff`, which means this `IntegerAttr.get` API can only support values up to 63 bits.
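To make the boundary concrete: judging from the `TypeError` at the end of this thread, the binding's `value` argument appears to be converted to a C `int64_t`. The largest accepted value is therefore 2**63 - 1 = `0x7fff_ffff_ffff_ffff`, and a full 64-bit mask such as `0xffff_ffff_ffff_ffff` (2**64 - 1) is rejected. Below is a plain-Python sketch of that range test. `fits_int64` is a hypothetical helper and is not part of the MLIR bindings.

```python
# Range of a signed 64-bit C integer (int64_t).
INT64_MIN = -(1 << 63)
INT64_MAX = (1 << 63) - 1


def fits_int64(value: int) -> bool:
    """Hypothetical helper: True if `value` could pass through an int64_t argument."""
    return INT64_MIN <= value <= INT64_MAX


print(hex(INT64_MAX))                     # 0x7fffffffffffffff
print(fits_int64(0x7fff_ffff_ffff_ffff))  # True
print(fits_int64(0xffff_ffff_ffff_ffff))  # False (2**64 - 1)
```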
For 64-bit integers, this worked:
value_attr = IntegerAttr.parse(str(value))
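The thread does not use the following approach; it is an alternative offered here as an assumption. For a signless `i64`, only the bit pattern matters, so an unsigned 64-bit value could be reinterpreted as its two's-complement signed equivalent before it reaches the `int64_t` argument. Under that reinterpretation, `0xffff_ffff_ffff_ffff` becomes -1. `to_signed64` is a hypothetical helper, and the commented-out binding call has not been tested here.

```python
def to_signed64(value: int) -> int:
    """Hypothetical helper: reinterpret an unsigned 64-bit value as signed.

    The 64-bit two's-complement pattern is unchanged; only the Python
    integer used to represent it differs.
    """
    if not 0 <= value < (1 << 64):
        raise ValueError(f"{value} is not an unsigned 64-bit value")
    return value - (1 << 64) if value >= (1 << 63) else value


print(to_signed64(0xffff_ffff_ffff_ffff))  # -1
print(to_signed64(0x8000_0000_0000_0000))  # -9223372036854775808
print(to_signed64(42))                     # 42

# Assumed usage with the MLIR Python bindings (not run here):
#   value_attr = IntegerAttr.get(IntegerType.get_signless(64), to_signed64(value))
```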
I have added a special case for 64-bit integers, and we now raise an MLIR-limitation error for integers wider than 64 bits.
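Here is a rough sketch of the dispatch described above. It is not the actual front-end code. `choose_int_attr_builder` is a hypothetical name, and it returns the name of the binding call instead of calling MLIR, so it runs without the bindings installed. We assume the check is keyed on the constant's value. The `uint128` test below uses only small constants, which would not pass a check on the type width alone. The sketch checks only the 64-bit limit, not whether `value` fits in `bits`.

```python
INT64_MIN = -(1 << 63)
INT64_MAX = (1 << 63) - 1


def choose_int_attr_builder(value: int, bits: int) -> str:
    """Hypothetical sketch of the dispatch described in the thread.

    Returns the name of the MLIR Python binding call assumed suitable for an
    integer constant `value` of width `bits`; raises for unsupported cases.
    """
    if INT64_MIN <= value <= INT64_MAX:
        # Fits the int64_t argument of IntegerAttr.get, whatever the type width.
        return "IntegerAttr.get"
    if bits == 64 and 0 <= value < (1 << 64):
        # Unsigned 64-bit values above INT64_MAX go through the text parser.
        return "IntegerAttr.parse"
    # Outside the int64 range and not an unsigned 64-bit constant.
    raise ValueError(
        f"MLIR limitation: cannot build a {bits}-bit constant for {value}"
    )


print(choose_int_attr_builder(0, 128))                     # IntegerAttr.get
print(choose_int_attr_builder(0xffff_ffff_ffff_ffff, 64))  # IntegerAttr.parse
```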
The test case above should work after pulling the latest changes from the front-end repository.
The following test case has been added to our test suite. Closing the thread since the issue is resolved.
import heterocl as hcl
import numpy as np

def test_gather128():
    hcl.init()

    def kernel():
        a32 = hcl.compute((4,), lambda i: i + 50, "a1", dtype='uint32')
        v = hcl.scalar(0, "x", dtype='uint128')
        factor = 128 // 32

        def shift_copy(i):
            v.v = 0
            for j in range(factor):  # j = 0, 1, 2, 3
                a = a32[i * factor + j]  # a = a32[0], a32[1], a32[2], a32[3]
                v.v = (v.v << 32) | a

        hcl.mutate((1,), shift_copy)
        res = hcl.compute((4,), lambda i: 0, "res", dtype='uint32')
        res[0] = (v.v >> 0) & 0xFFFFFFFF  # should be a32[3]
        res[1] = (v.v >> 32) & 0xFFFFFFFF  # should be a32[2]
        res[2] = (v.v >> 64) & 0xFFFFFFFF  # should be a32[1]
        res[3] = (v.v >> 96) & 0xFFFFFFFF  # should be a32[0]
        return res

    s = hcl.create_schedule([], kernel)
    hcl_res = hcl.asarray(np.zeros((4,), dtype=np.uint32), dtype=hcl.UInt(32))
    f = hcl.build(s)
    f(hcl_res)
    golden = np.array([53, 52, 51, 50], dtype=np.uint32)
    assert np.allclose(hcl_res.asnumpy(), golden)
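The golden values follow from the shift-and-OR loop. After four iterations, `a32[0]` sits in the top 32 bits and `a32[3]` in the bottom 32 bits, so extracting 32-bit lanes from the low end returns the inputs in reverse order. The plain-Python model below reproduces that arithmetic with Python's arbitrary-precision integers. It does not use HeteroCL, and `pack_words`/`unpack_low_first` are hypothetical names.

```python
def pack_words(words, width=32):
    """Pack words most-significant first, mirroring `v = (v << 32) | a`."""
    v = 0
    for w in words:
        v = (v << width) | (w & ((1 << width) - 1))
    return v


def unpack_low_first(v, count, width=32):
    """Extract `count` lanes of `width` bits, starting from the least-significant end."""
    return [(v >> (width * k)) & ((1 << width) - 1) for k in range(count)]


a32 = [i + 50 for i in range(4)]  # [50, 51, 52, 53], as computed in the kernel
v = pack_words(a32)
print(hex(v))                     # 0x32000000330000003400000035
print(unpack_low_first(v, 4))     # [53, 52, 51, 50], matching `golden`
```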
import heterocl as hcl
import numpy as np

def test_mask64():
    hcl.init()

    def kernel():
        v = hcl.scalar(0, "v", hcl.UInt(64))
        v.v = v.v & 0xffff_ffff_ffff_ffff  # 64-bit all-ones mask constant
        r = hcl.compute((4,), lambda i: 0, dtype=hcl.UInt(32))
        return r

    s = hcl.create_schedule([], kernel)
    print(hcl.lower(s))
    hcl_res = hcl.asarray(np.zeros((4,), dtype=np.uint32), dtype=hcl.UInt(32))
    f = hcl.build(s)
    f(hcl_res)
Running it generates:
TypeError: get(): incompatible function arguments. The following argument types are supported:
Invoked with: Type(i64), 18446744073709551615