Closed by jcasas00 1 year ago
This seems to be an issue with the call operation. The following IR reproduces it:
// RUN: hcl-opt --lower-print-ops --jit %s
module {
  func.func @top() -> () {
    %x = arith.constant 0 : i32
    hcl.print(%x) {format="x: %d \n"} : i32
    %y = arith.constant 1 : i32
    hcl.print(%y) {format="y: %d \n"} : i32
    return
  }
}
And the lowered LLVM IR is:
module {
  llvm.mlir.global internal constant @frmt_spec1("y: %.0f \0A")
  llvm.mlir.global internal constant @frmt_spec0("x: %.0f \0A")
  llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
  llvm.func @top() {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.addressof @frmt_spec0 : !llvm.ptr<array<9 x i8>>
    %2 = llvm.mlir.constant(0 : index) : i64
    %3 = llvm.getelementptr %1[%2, %2] : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
    %4 = llvm.mlir.constant(0 : i64) : i64
    %5 = llvm.sitofp %4 : i64 to f64
    %6 = llvm.call @printf(%3, %5) : (!llvm.ptr<i8>, f64) -> i32
    %7 = llvm.mlir.constant(1 : i32) : i32
    %8 = llvm.mlir.addressof @frmt_spec1 : !llvm.ptr<array<9 x i8>>
    %9 = llvm.mlir.constant(0 : index) : i64
    %10 = llvm.getelementptr %8[%9, %9] : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
    %11 = llvm.mlir.constant(1 : i64) : i64
    %12 = llvm.sitofp %11 : i64 to f64
    %13 = llvm.call @printf(%10, %12) : (!llvm.ptr<i8>, f64) -> i32
    llvm.return
  }
}
The lowered LLVM dialect IR seems correct. I'm looking into what's causing this issue.
I figured it out: it's an alignment issue with the global strings, e.g. llvm.mlir.global internal constant @frmt_spec1("y: %.0f \0A"). llvm.getelementptr would get the next global string, so there's an extra print.
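A rough sketch of why the extra print shows up, written in plain Python rather than LLVM. It assumes the two format strings end up back to back in memory, which is an assumption about the layout (the real placement is up to the JIT/linker), and c_read_string is a hypothetical helper that mimics how printf scans for a terminating zero byte:

```python
def c_read_string(memory: bytes, start: int) -> bytes:
    """Read bytes from `start` up to (not including) the next zero byte,
    the way a C routine such as printf walks a format string."""
    end = memory.index(b"\x00", start)
    return memory[start:end]


# Assumed layout: two globals stored back to back with no terminator
# between them (a single zero byte happens to follow the second one).
unterminated = b"x: %d \n" + b"y: %d \n" + b"\x00"

# Same two strings, each with its own terminator.
terminated = b"x: %d \n\x00" + b"y: %d \n\x00"

# Reading the first string runs straight into the second one.
print(c_read_string(unterminated, 0))  # b'x: %d \ny: %d \n'

# With terminators, the read stops where the first string ends.
print(c_read_string(terminated, 0))    # b'x: %d \n'
```

In the unterminated case the first read returns both format strings, which is roughly the "extra print" behaviour described above.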
Still seeing extra outputs.
def kernel():
    z = hcl.scalar(3, "z", dtype=hcl.UInt(16))
    hcl.print((z.v), "zz=%d ")
    hcl.print((z.v,z.v), "aaaaaaaaaaaaa=%d bbbbbbbb=%d")
    hcl.print((), " \n")
    #
    r = hcl.compute((1,), lambda _:0, dtype=hcl.UInt(32))
    return r
generates the output:
zz=3 aaaaaaaaaaaaa=3 bbbbbbbb=3zz=0
Note the extra characters after the first "zz=3", and the "zz=0" that appears at the end. There seems to be some overflow happening with the lines that don't end with a \n.
An LLVM string requires a \00 terminator, the same as the \0 in a C string. I encountered the same issue when reading/writing files; a more detailed description is here: https://github.com/zzzDavid/hcl-debug/tree/main/read_write_file
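For illustration, here is a minimal sketch of how a format string could be turned into a terminated global initializer. to_mlir_c_string is a hypothetical helper, not the code from the actual fix; it only shows the idea of appending the zero byte and counting it in the array length:

```python
from typing import Tuple


def to_mlir_c_string(fmt: str) -> Tuple[str, int]:
    """Return (escaped_text, byte_length) for a zero-terminated string.

    Printable ASCII is kept as-is; every other byte, plus the double
    quote and the backslash, is written as a two-digit hex escape.
    """
    data = fmt.encode("utf-8") + b"\x00"  # append the terminator
    parts = []
    for b in data:
        if 0x20 <= b < 0x7F and b not in (0x22, 0x5C):
            parts.append(chr(b))
        else:
            parts.append("\\{:02X}".format(b))
    return "".join(parts), len(data)


text, length = to_mlir_c_string("x: %.0f \n")
print(text)    # x: %.0f \0A\00
print(length)  # 10
```

Note that the length becomes 10 instead of the 9 seen in the array<9 x i8> types above, so the pointer types that refer to the global would have to change along with the initializer.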
This issue will be closed after test cases are added.
Fixed by cornell-zhang/heterocl@90be1921a3ca94e68676a4b68a8227dff5938453
When I run this code, I get:
If either of the hcl.print calls is commented out, that 3rd line is not generated.