zzzDavid opened this issue 2 years ago (status: Open)
### Description

`hcl.sum()` ignores the default data type set by `hcl.init` and uses `int32` whenever `dtype` is not set.
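A minimal sketch of the fallback rule the report expects, in plain Python with no HeteroCL dependency. The names `init` and `resolve_sum_dtype` and the module-level `_config` dictionary are hypothetical stand-ins, not HeteroCL internals. The sketch only encodes the expected behavior: an explicit `dtype` wins, and otherwise the default recorded by `init` applies instead of a hard-coded `int32`.

```python
# Hypothetical stand-in for the global default that hcl.init() records.
# These names are illustrative only; they are not HeteroCL's internals.
_config = {"init_dtype": "int32"}


def init(dtype="int32"):
    """Record the default data type, analogous to hcl.init(dtype)."""
    _config["init_dtype"] = dtype


def resolve_sum_dtype(dtype=None):
    """Return the data type a sum reduction should use.

    An explicitly passed dtype takes precedence; otherwise fall back to
    the default recorded by init(), rather than always using "int32".
    """
    return dtype if dtype is not None else _config["init_dtype"]


if __name__ == "__main__":
    init("float32")
    print(resolve_sum_dtype())        # default taken from init()
    print(resolve_sum_dtype("int8"))  # explicit dtype takes precedence
```

Going by the report, passing `dtype` explicitly to `hcl.sum` (for example `dtype=hcl.Float(32)`) should avoid the `int32` fallback until the default is respected; this workaround has not been verified here.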
### Example

A small matrix-multiplication example that reproduces the issue:
```python
import heterocl as hcl
import numpy as np

# Set the default data type to 32-bit float
hcl.init(hcl.Float(32))

m = 2
k = 2
n = 2

matrix_1 = hcl.placeholder((m, k))
matrix_2 = hcl.placeholder((k, n))

def kernel(matrix_1, matrix_2):
    r = hcl.reduce_axis(0, k, 'k')
    # No dtype is passed to hcl.sum, so it should use the default set by hcl.init
    return hcl.compute((m, n),
                       lambda x, y: hcl.sum(matrix_1[x, r] * matrix_2[r, y], axis=r),
                       name="out_matrix")

s = hcl.create_schedule([matrix_1, matrix_2], kernel)
f = hcl.build(s)

A_np = np.random.rand(m, k).astype(float)
B_np = np.random.rand(k, n).astype(float)
A = hcl.asarray(A_np)
B = hcl.asarray(B_np)
C = hcl.asarray(np.zeros((m, n)).astype(float))

f(A, B, C)

C_np = C.asnumpy()
C_ref = np.matmul(A_np, B_np)
# Fails because of this issue: the values in C_np are integers
assert np.allclose(C_np, C_ref)
```
Printing `C_np` shows that its data type is still float, but all of its values have been rounded to integers.
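To illustrate how the output can be float-typed yet integer-valued, the NumPy-only sketch below emulates a reduction whose running sum is held in an `int32` accumulator while the result is written into a float array. The helper `matmul_with_accumulator` is hypothetical, and casting after every addition is an assumption about the generated code, not HeteroCL's actual implementation. NumPy's float-to-`int32` cast truncates toward zero; the exact rounding HeteroCL applies is not established here.

```python
import numpy as np


def matmul_with_accumulator(a, b, acc_dtype):
    """Naive matmul whose running sum is stored as acc_dtype.

    The partial sum is cast to acc_dtype after every addition, emulating
    a reduction whose accumulator has that type (an assumption, not
    HeteroCL's implementation). The output array is always float64.
    """
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    out = np.zeros((m, n), dtype=np.float64)
    for x in range(m):
        for y in range(n):
            acc = acc_dtype(0)
            for r in range(k):
                # Cast back to the accumulator type, as a typed variable would
                acc = acc_dtype(acc + a[x, r] * b[r, y])
            out[x, y] = acc
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A_np = rng.random((2, 2))
    B_np = rng.random((2, 2))
    print(matmul_with_accumulator(A_np, B_np, np.int32))    # float array, integer values
    print(matmul_with_accumulator(A_np, B_np, np.float32))  # close to A_np @ B_np
```

Under this emulation, inputs drawn from [0, 1) make every product smaller than 1, so each truncated partial sum stays at 0 and the `int32` variant returns all zeros. This is consistent with the assertion in the reproduction failing.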