Open paldebjit opened 3 years ago
Applying hcl.compute_at() changes the Tensor sizes and the data-layout.
Executing the following code creates out_AB[16][18], out_ABC[16][24].
```python
import heterocl as hcl


def top_2mm(P=16, Q=22, R=18, S=24, alpha=0.1, beta=0.1,
            dtype=hcl.Float(), target=None):
    hcl.init(dtype)
    A = hcl.placeholder((P, Q), "A")
    B = hcl.placeholder((Q, R), "B")
    C = hcl.placeholder((R, S), "C")
    D = hcl.placeholder((P, S), "D")

    def kernel_2mm(A, B, C, D):
        r = hcl.reduce_axis(0, Q, "r")
        out_AB = hcl.compute(
            (P, R),
            lambda x, y: hcl.sum(A[x, r] * B[r, y], axis=r, dtype=dtype),
            name="out_AB",
        )
        k = hcl.reduce_axis(0, R, "k")
        out_ABC = hcl.compute(
            (P, S),
            lambda x, y: hcl.sum(out_AB[x, k] * C[k, y], axis=k, dtype=dtype),
            name="out_ABC",
        )
        hcl.update(
            D,
            lambda x, y: (alpha * out_ABC[x, y] + beta * D[x, y]),
            name="D",
        )

    s = hcl.create_schedule([A, B, C, D], kernel_2mm)
    print(hcl.build(s, target=target))


f = top_2mm(target="vhls")
```
However, executing the following code
```python
import heterocl as hcl


def top_2mm(P=16, Q=22, R=18, S=24, alpha=0.1, beta=0.1,
            dtype=hcl.Float(), target=None):
    hcl.init(dtype)
    A = hcl.placeholder((P, Q), "A")
    B = hcl.placeholder((Q, R), "B")
    C = hcl.placeholder((R, S), "C")
    D = hcl.placeholder((P, S), "D")

    def kernel_2mm(A, B, C, D):
        r = hcl.reduce_axis(0, Q, "r")
        out_AB = hcl.compute(
            (P, R),
            lambda x, y: hcl.sum(A[x, r] * B[r, y], axis=r, dtype=dtype),
            name="out_AB",
        )
        k = hcl.reduce_axis(0, R, "k")
        out_ABC = hcl.compute(
            (P, S),
            lambda x, y: hcl.sum(out_AB[x, k] * C[k, y], axis=k, dtype=dtype),
            name="out_ABC",
        )
        hcl.update(
            D,
            lambda x, y: (alpha * out_ABC[x, y] + beta * D[x, y]),
            name="D",
        )

    s = hcl.create_schedule([A, B, C, D], kernel_2mm)

    #### Apply customizations ####
    # Note: A, B, and D are rebound here to the kernel stages.
    A = kernel_2mm.out_AB
    B = kernel_2mm.out_ABC
    D = kernel_2mm.D

    s[A].compute_at(s[B], B.axis[0])
    s[B].compute_at(s[D], D.axis[0])
    #### Apply customizations ####

    print(hcl.build(s, target=target))


f = top_2mm(target="vhls")
```
creates out_AB[1][18] and out_ABC[1][24], i.e., the first dimension of both Tensors shrinks from 16 to 1. This is probably not the intended behavior.
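For reference, here is a plain-Python sketch (no HeteroCL) of the loop structure that `compute_at` on the outermost axis might produce. It rests on an assumption that is not confirmed anywhere in this report: that the producer stage is moved inside the consumer's outer loop, so only one row of each intermediate is live at a time. The helper names (`matmul_rows`, `two_mm_unfused`, `two_mm_fused`, `make_matrix`) are hypothetical and exist only for this illustration.

```python
# Plain-Python illustration only; this is NOT HeteroCL code and does not
# claim to reproduce what the HeteroCL code generator actually emits.

def matmul_rows(lhs_rows, rhs):
    """Multiply a list of rows (n x K) by a K x M matrix stored as nested lists."""
    inner = len(rhs)
    cols = len(rhs[0])
    return [[sum(row[k] * rhs[k][j] for k in range(inner)) for j in range(cols)]
            for row in lhs_rows]


def two_mm_unfused(A, B, C, D, alpha, beta):
    """Baseline: intermediates are full P x R and P x S buffers."""
    out_AB = matmul_rows(A, B)
    out_ABC = matmul_rows(out_AB, C)
    return [[alpha * out_ABC[x][y] + beta * D[x][y] for y in range(len(D[0]))]
            for x in range(len(D))]


def two_mm_fused(A, B, C, D, alpha, beta):
    """Assumed effect of compute_at on axis 0: intermediates are 1 x R and 1 x S."""
    result = []
    for x in range(len(D)):
        out_AB = matmul_rows([A[x]], B)   # one row of A @ B
        out_ABC = matmul_rows(out_AB, C)  # one row of (A @ B) @ C
        result.append([alpha * out_ABC[0][y] + beta * D[x][y]
                       for y in range(len(D[0]))])
    return result


def make_matrix(rows, cols, seed):
    """Deterministic test data; the values are arbitrary."""
    return [[((i * cols + j + seed) % 7) * 0.5 for j in range(cols)]
            for i in range(rows)]


P, Q, R, S = 16, 22, 18, 24
A = make_matrix(P, Q, 1)
B = make_matrix(Q, R, 2)
C = make_matrix(R, S, 3)
D = make_matrix(P, S, 4)

fused = two_mm_fused(A, B, C, D, 0.1, 0.1)
unfused = two_mm_unfused(A, B, C, D, 0.1, 0.1)
print(fused == unfused)
```

If that assumption holds, the smaller buffers would not change the numerical result of `D`. The question raised in this issue is whether the declared Tensor size and data layout should change at all.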