cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0

hcl.Struct doesn't support dtype specified as strings (e.g., 'uint16') #410

Closed jcasas00 closed 3 years ago

jcasas00 commented 3 years ago

Code:

hcl.placeholder((), "A", hcl.UInt(16))
hcl.placeholder((), "A", 'uint16')

hcl.Struct ({'foo': hcl.UInt(16) })
hcl.Struct ({'foo': 'uint16' })

Output:

>>> hcl.placeholder((), "A", hcl.UInt(16))                          # OK
<heterocl.tensor.Scalar object at 0x7f89f723f670>
>>> hcl.placeholder((), "A", 'uint16')                                # OK
<heterocl.tensor.Scalar object at 0x7f89f7155a60>
>>>
>>> hcl.Struct ({'foo': hcl.UInt(16) })                                 # OK
Struct(OrderedDict([('foo', UInt(16))]))
>>> hcl.Struct ({'foo': 'uint16' })                                        # not OK
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/heterocl/python/heterocl/types.py", line 65, in __init__
    self.bits += dtype.bits
AttributeError: 'str' object has no attribute 'bits'

hcl.Struct should support both forms of dtype specifications as other APIs do (e.g., hcl.placeholder [as above], hcl.scalar, etc.) for consistency.
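
For illustration, here is a minimal standalone sketch of the kind of normalization being requested. It does not use HeteroCL: the `UInt` class is a stand-in for any dtype object exposing `.bits`, and `bitwidth_of` and `struct_bits` are hypothetical helper names, not the library's API. It also assumes string dtypes take the form of a type name followed by a width (as in 'uint16'); other spellings are not handled.

```python
import re


class UInt:
    """Stand-in for a dtype object that exposes .bits (not hcl.UInt)."""

    def __init__(self, bits):
        self.bits = bits


# Assumed string form: "int", "uint", or "float" followed by a bit width.
_DTYPE_RE = re.compile(r"^(u?int|float)(\d+)$")


def bitwidth_of(dtype):
    """Return the bit width of a dtype given as an object or as a string."""
    if isinstance(dtype, str):
        match = _DTYPE_RE.match(dtype)
        if match is None:
            raise ValueError(f"unrecognized dtype string: {dtype!r}")
        return int(match.group(2))
    # Non-string dtypes are expected to carry their width in .bits.
    return dtype.bits


def struct_bits(dtype_dict):
    """Total width of a struct whose fields may mix both dtype forms."""
    return sum(bitwidth_of(dtype) for dtype in dtype_dict.values())


total_obj = struct_bits({"foo": UInt(16)})
total_str = struct_bits({"foo": "uint16", "bar": "int8"})
```

Routing every width lookup through one helper like this means a string dtype never reaches a bare `.bits` access, which is what raised the AttributeError above.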

jcasas00 commented 3 years ago

A similar issue occurs in the TensorSlice class's __getattr__ and __setattr__ methods.

seanlatias commented 3 years ago

> A similar issue occurs in the TensorSlice class's __getattr__ and __setattr__ methods.

Shouldn't these two methods be fine, since we are accessing with keys (i.e., the name field)?

jcasas00 commented 3 years ago

After I temporarily patched hcl.Struct (replacing uses of dtype.bits with get_bitwidth(dtype)), it worked with the test case I showed above. But my bigger program hit the same problem in the TensorSlice class, where I had to apply the same change to the following code in both __getattr__ and __setattr__:

    for dkey, dval in hcl_dtype.dtype_dict.items():
        if dkey == key:
            # end = start + dval.bits                # old code
            end = start + types.get_bitwidth(dval)   # new code
            dtype = types.dtype_to_str(dval)
            break
        else:
            # start += dval.bits                     # old code
            start += types.get_bitwidth(dval)        # new code
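
The loop above walks the struct's fields in order and accumulates bit offsets until it reaches the requested key. A standalone sketch of that offset computation follows. It assumes the field widths are already known as plain integers; `field_bit_range` is a hypothetical name used only for illustration, not a HeteroCL function.

```python
from collections import OrderedDict


def field_bit_range(field_widths, key):
    """Return (start, end) bit offsets of `key` within a packed struct.

    `field_widths` maps field names to bit widths in declaration order.
    """
    start = 0
    for name, width in field_widths.items():
        if name == key:
            return start, start + width
        # Not the requested field: skip past its bits.
        start += width
    raise KeyError(key)


fields = OrderedDict([("foo", 16), ("bar", 8), ("baz", 32)])
bar_range = field_bit_range(fields, "bar")
baz_range = field_bit_range(fields, "baz")
```

Because every field before the requested one contributes to the offset, each field's width must be resolvable regardless of how its dtype was specified.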

seanlatias commented 3 years ago

I see, but my fix would not generate such a problem. Will let you know when it's done.

jcasas00 commented 3 years ago

That would be a better fix then. Thanks.

seanlatias commented 3 years ago

Please check #411.

jcasas00 commented 3 years ago

I tested the latest master and it works for me now. Thanks.