cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0
322 stars, 92 forks

`hcl.sum()` Not Using Default Data Type #426

Open zzzDavid opened 2 years ago

zzzDavid commented 2 years ago

Description

`hcl.sum()` ignores the default data type set by `hcl.init()` and falls back to `int32` when no `dtype` argument is given.

Example

A small matmul example to reproduce this issue:

```python
import heterocl as hcl
import numpy as np

# Set the default data type to 32-bit float
hcl.init(hcl.Float(32))

m = 2
k = 2
n = 2

matrix_1 = hcl.placeholder((m, k))
matrix_2 = hcl.placeholder((k, n))

def kernel(matrix_1, matrix_2):
    r = hcl.reduce_axis(0, k, 'k')
    # No dtype is passed to hcl.sum, so it should use the Float(32)
    # default from hcl.init; instead it accumulates in int32
    return hcl.compute(
        (m, n),
        lambda x, y: hcl.sum(matrix_1[x, r] * matrix_2[r, y], axis=r),
        name="out_matrix")

s = hcl.create_schedule([matrix_1, matrix_2], kernel)
f = hcl.build(s)

A_np = np.random.rand(m, k).astype(float)
B_np = np.random.rand(k, n).astype(float)

A = hcl.asarray(A_np)
B = hcl.asarray(B_np)
C = hcl.asarray(np.zeros((m, n)).astype(float))

f(A, B, C)

C_np = C.asnumpy()
C_ref = np.matmul(A_np, B_np)

# Fails: C_np contains only integer values
assert np.allclose(C_np, C_ref)
```

Printing `C_np` shows that its data type is still float, but every value has been rounded to an integer, so the final assertion fails.
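To make the symptom easier to see without building HeteroCL, here is a pure NumPy sketch that emulates the reduction with a configurable accumulator type. It assumes (this is not confirmed by the report) that the running sum is cast back to the accumulator type after every update and the result is then written into a float32 output, which would explain float-typed output holding integer values. `emulated_matmul` is a hypothetical helper, not part of HeteroCL.

```python
import numpy as np

def emulated_matmul(a, b, acc_dtype):
    """Hypothetical emulation of hcl.sum-style accumulation.

    The running sum is re-cast to acc_dtype after each step, and the
    final value is stored in a float32 output array.
    """
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    out = np.zeros((m, n), dtype=np.float32)
    for x in range(m):
        for y in range(n):
            acc = acc_dtype(0)
            for r in range(k):
                # Casting to np.int32 truncates the fractional part
                acc = acc_dtype(acc + a[x, r] * b[r, y])
            out[x, y] = acc
    return out

a = np.array([[0.5, 0.75], [1.5, 2.0]])
b = np.array([[1.0, 2.0], [0.5, 0.25]])

as_int = emulated_matmul(a, b, np.int32)      # values: [[0, 1], [2, 3]]
as_float = emulated_matmul(a, b, np.float32)  # values: [[0.875, 1.1875], [2.5, 3.5]]
```

With an int32 accumulator the output array is still float32, but each entry has lost its fractional part, matching what the report describes for `C_np`.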