cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0

`hcl.compute_at()` changing tensor sizes and data-layout #390

Open paldebjit opened 3 years ago

paldebjit commented 3 years ago

Issue statement

Applying hcl.compute_at() changes the tensor sizes and the data layout of the intermediate results.

How to reproduce

Executing the following code generates out_AB[16][18] and out_ABC[16][24], matching the declared (P, R) and (P, S) shapes.

import heterocl as hcl

def top_2mm(P=16, Q=22, R=18, S=24, alpha=0.1, beta=0.1, dtype=hcl.Float(), target=None):

    hcl.init(dtype)
    A = hcl.placeholder((P, Q), "A")
    B = hcl.placeholder((Q, R), "B")
    C = hcl.placeholder((R, S), "C")
    D = hcl.placeholder((P, S), "D")

    def kernel_2mm(A, B, C, D):

        r = hcl.reduce_axis(0, Q, "r")
        out_AB = hcl.compute(
            (P, R),
            lambda x, y: hcl.sum(A[x, r] * B[r, y], axis=r, dtype=dtype),
            name="out_AB"
        )

        k = hcl.reduce_axis(0, R, "k")
        out_ABC = hcl.compute(
            (P, S),
            lambda x, y: hcl.sum(out_AB[x, k] * C[k, y], axis=k, dtype=dtype),
            name="out_ABC"
        )

        hcl.update(
            D,
            lambda x, y: alpha * out_ABC[x, y] + beta * D[x, y],
            name="D"
        )

    s = hcl.create_schedule([A, B, C, D], kernel_2mm)

    print(hcl.build(s, target=target))

f = top_2mm(target="vhls")
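
For reference, a C-style sketch of the structure one would expect from this unscheduled version (an approximation, not the literal generated VHLS code; the function name and loop order are my assumptions). Each stage owns a full-size intermediate buffer, which is consistent with the out_AB[16][18] and out_ABC[16][24] declarations observed above.

// Sketch of the unscheduled structure (assumed, not the literal VHLS output):
// every stage writes a full-size intermediate buffer.
void kernel_2mm_sketch(float A[16][22], float B[22][18],
                       float C[18][24], float D[16][24],
                       float alpha, float beta) {
    float out_AB[16][18] = {0};   /* (P, R) */
    float out_ABC[16][24] = {0};  /* (P, S) */

    for (int x = 0; x < 16; x++)          /* out_AB stage */
        for (int y = 0; y < 18; y++)
            for (int r = 0; r < 22; r++)
                out_AB[x][y] += A[x][r] * B[r][y];

    for (int x = 0; x < 16; x++)          /* out_ABC stage */
        for (int y = 0; y < 24; y++)
            for (int k = 0; k < 18; k++)
                out_ABC[x][y] += out_AB[x][k] * C[k][y];

    for (int x = 0; x < 16; x++)          /* D update stage */
        for (int y = 0; y < 24; y++)
            D[x][y] = alpha * out_ABC[x][y] + beta * D[x][y];
}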

However, executing the following code, which applies two compute_at() customizations to the same kernel,


import heterocl as hcl

def top_2mm(P=16, Q=22, R=18, S=24, alpha=0.1, beta=0.1, dtype=hcl.Float(), target=None):

    hcl.init(dtype)
    A = hcl.placeholder((P, Q), "A")
    B = hcl.placeholder((Q, R), "B")
    C = hcl.placeholder((R, S), "C")
    D = hcl.placeholder((P, S), "D")

    def kernel_2mm(A, B, C, D):

        r = hcl.reduce_axis(0, Q, "r")
        out_AB = hcl.compute(
            (P, R),
            lambda x, y: hcl.sum(A[x, r] * B[r, y], axis=r, dtype=dtype),
            name="out_AB"
        )

        k = hcl.reduce_axis(0, R, "k")
        out_ABC = hcl.compute(
            (P, S),
            lambda x, y: hcl.sum(out_AB[x, k] * C[k, y], axis=k, dtype=dtype),
            name="out_ABC"
        )

        hcl.update(
            D,
            lambda x, y: alpha * out_ABC[x, y] + beta * D[x, y],
            name="D"
        )

    s = hcl.create_schedule([A, B, C, D], kernel_2mm)

    #### Apply customizations ####

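    # Handles to the kernel stages (note: this rebinds the placeholder names A, B, D)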
    A = kernel_2mm.out_AB
    B = kernel_2mm.out_ABC
    D = kernel_2mm.D

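    # Compute each producer stage inside the outer (row) loop of its consumer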
    s[A].compute_at(s[B], B.axis[0])
    s[B].compute_at(s[D], D.axis[0])

    #### Apply customizations ####

    print(hcl.build(s, target=target))

f = top_2mm(target="vhls")

creates out_AB[1][18] and out_ABC[1][24], i.e., the tensor sizes change. This should probably not be the case.
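
My reading of those shapes (an assumption on my part, not necessarily what the backend intends) is that compute_at() moves each producer under the outer loop of its consumer, so only one row of each intermediate is live at a time. A C-style sketch of that structure (approximate, not the literal VHLS output; the function name is made up):

// Sketch of the fused structure after the two compute_at() calls (assumed):
// each producer row is recomputed inside the consumer's outer x loop, so the
// intermediates shrink to a single row, matching the [1][18] / [1][24] shapes.
void kernel_2mm_fused_sketch(float A[16][22], float B[22][18],
                             float C[18][24], float D[16][24],
                             float alpha, float beta) {
    for (int x = 0; x < 16; x++) {        /* D's outer axis */
        float out_AB[1][18] = {0};        /* out_AB computed at out_ABC's x axis */
        float out_ABC[1][24] = {0};       /* out_ABC computed at D's x axis */

        for (int y = 0; y < 18; y++)      /* one row of out_AB for this x */
            for (int r = 0; r < 22; r++)
                out_AB[0][y] += A[x][r] * B[r][y];

        for (int y = 0; y < 24; y++)      /* one row of out_ABC for this x */
            for (int k = 0; k < 18; k++)
                out_ABC[0][y] += out_AB[0][k] * C[k][y];

        for (int y = 0; y < 24; y++)      /* one row of the D update */
            D[x][y] = alpha * out_ABC[0][y] + beta * D[x][y];
    }
}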