cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0
322 stars 92 forks source link

Empty Parenthesis in Generated SHLS Code #459

Closed zzzDavid closed 2 years ago

zzzDavid commented 2 years ago

Description

Code generated for the Stratus HLS backend contains an empty pair of parentheses (a cast with no type inside), which fails checking in Stratus HLS.

Example

def reduce(b, w, q):
    """Build the reduction expression used by the reproducer.

    The bit width ``bw`` is taken from ``q.dtype``; ``b``, ``w`` and ``q`` are
    each wrapped in an ``hcl.scalar`` of type ``hcl.UInt(bw)``. The product
    ``w * b`` then goes through two rounds of
    ``a = (a + ((-a) & mask) * q) >> (bw // 2)`` followed by a final
    ``hcl.select(a < q, a, a - q)``.

    NOTE(review): this looks like a Montgomery-style modular reduction --
    confirm with the author before relying on that reading.

    Returns:
        An ``hcl.scalar`` named "reduce" of type ``hcl.UInt(bw)`` holding the
        selected value.
    """
    # Must read the dtype before `q` is rebound to a new scalar below.
    bw  = hcl.get_bitwidth(q.dtype)
    bwh = bw // 2
    # Mask with the low `bwh` bits set.
    mask = (1 << bwh) - 1

    b = hcl.scalar(b, "b", dtype=hcl.UInt(bw))
    w = hcl.scalar(w, "w", dtype=hcl.UInt(bw))
    q = hcl.scalar(q, "q", dtype=hcl.UInt(bw))

    a = w * b
    # Two rounds of the same update; the loop index `i` is not used.
    for i in range(2):
        t = (-a) & mask
        s = (a + (t*q)) >> bwh
        a = s
    # The `a - q` operand is the false branch of the select; per this report,
    # that is where the generated code shows the empty cast.
    a = hcl.select(a < q, a, a - q)
    res = hcl.scalar(a, "reduce", dtype=hcl.UInt(bw))
    return res

# Reproducer: the SHLS code generated for this test contains an empty type
# cast "()".
def test2():
    """Build a single-statement kernel that calls reduce() for the "shls" target."""

    def func(A,B):
        # Only B[0] is written; the three inputs are A[0], A[1] and A[2].
        B[0] = reduce(A[0], A[1], A[2])

    A = hcl.placeholder((5,), "A", dtype=hcl.UInt(32))
    B = hcl.placeholder((2,), "B", dtype=hcl.UInt(32))
    s = hcl.create_schedule([A,B], func)

    # Building for the "shls" target is the step that produces the reported
    # code; the result `m` is not used further in this snippet.
    m = hcl.build(s, "shls")

The generated SHLS code:

#include "default_function.h"
// Generated thread body quoted verbatim from the report. It resets the
// ports, then loops forever: three reads from A (into b, w, q), one
// computation of `reduce`, one write to B, and a `finish` signal.
void default_function::thread1()
{
  {
    HLS_DEFINE_PROTOCOL("reset");
    B.reset();
    A.reset();
    finish.write(0);
    wait();
  }
  while( true ) 
  {
    b_x: for (sc_int<32> x = 0; x < 1; ++x) {
            b = A.get();
    }
    w_x1: for (sc_int<32> x1 = 0; x1 < 1; ++x1) {
            w = A.get();
    }
    q_x2: for (sc_int<32> x2 = 0; x2 < 1; ++x2) {
            q = A.get();
    }
    reduce_x3: for (sc_int<32> x3 = 0; x3 < 1; ++x3) {
      // BUG (subject of this report): in the ternary below, the operand after
      // `:` starts with `(()`, i.e. a cast whose type is missing. The operand
      // before `:` has no such cast.
      reduce = ((sc_uint<32>)(((((((((sc_uint<64>)w) * ((sc_uint<64>)b)) + ((((((sc_uint<64>)w) * ((sc_uint<64>)b)) * (sc_uint<64>)18446744073709551615) & 65535) * q)) >> 16) + ((((((((sc_uint<64>)w) * ((sc_uint<64>)b)) + ((((((sc_uint<64>)w) * ((sc_uint<64>)b)) * (sc_uint<64>)18446744073709551615) & 65535) * q)) >> 16) * 18446744073709551615) & 65535) * q)) >> 16) < q) ? ((((((((sc_uint<64>)w) * ((sc_uint<64>)b)) + ((((((sc_uint<64>)w) * ((sc_uint<64>)b)) * (sc_uint<64>)18446744073709551615) & 65535) * q)) >> 16) + ((((((((sc_uint<64>)w) * ((sc_uint<64>)b)) + ((((((sc_uint<64>)w) * ((sc_uint<64>)b)) * (sc_uint<64>)18446744073709551615) & 65535) * q)) >> 16) * 18446744073709551615) & 65535) * q)) >> 16)) : (()((((((((sc_uint<64>)w) * ((sc_uint<64>)b)) + ((((((sc_uint<64>)w) * ((sc_uint<64>)b)) * (sc_uint<64>)18446744073709551615) & 65535) * q)) >> 16) + ((((((((sc_uint<64>)w) * ((sc_uint<64>)b)) + ((((((sc_uint<64>)w) * ((sc_uint<64>)b)) * (sc_uint<64>)18446744073709551615) & 65535) * q)) >> 16) * 18446744073709551615) & 65535) * q)) >> 16) - q))));
    }
    B.put(reduce);
    finish.write(true);
  }
}

The empty parentheses appear here, at the start of the operand after the `:` of the ternary expression: (()((((((((sc_uint<64>)w) * ((sc_uint<64>)b))