cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0

Incorrectly removed cast node #386

Closed (zzzDavid closed 2 years ago)

zzzDavid commented 3 years ago

Minimum test case

import heterocl as hcl

def test():
    A = hcl.placeholder((10,10), dtype=hcl.UInt(16), name="A")
    B = hcl.placeholder((10,10), dtype=hcl.Int(16), name="B")

    def algo(A, B):
        def f_mutate(i,j):
            factor = hcl.scalar(B[0][0][13:11], name="factor")
            idx = hcl.scalar(B[0][0][11:0], dtype=hcl.UInt(16), name="idx")
            idx += i * hcl.cast(hcl.UInt(16), factor.v) # this cast was incorrectly removed
            A[idx][j] = B[idx][j]
        bound = hcl.scalar(5, dtype=hcl.Int(32))
        domain = (hcl.cast(hcl.UInt(32), bound.v), hcl.cast(hcl.UInt(32), bound.v))
        hcl.mutate(domain, f_mutate)

    s = hcl.create_schedule([A, B], algo)
    f = hcl.build(s, target="vhls")

test()  # invoke the reproducer

Error stack

heterocl.tvm._ffi.base.TVMError: [14:22:09] HalideIR/src/ir/IR.h:152: Check failed: a.type() == b.type() BinaryOp of mismatched types

Stack trace returned 10 entries:
[bt] (0) 0   libhcl.dylib                        0x000000011c1d0bde dmlc::StackTrace() + 254
[bt] (1) 1   libhcl.dylib                        0x000000011c1d098f dmlc::LogMessageFatal::~LogMessageFatal() + 47
[bt] (2) 2   libhcl.dylib                        0x000000011c1e0a2b Halide::Internal::Mul::make(Halide::Expr, Halide::Expr) + 1195
[bt] (3) 3   libhcl.dylib                        0x000000011c389861 TVM::ir::IRMutator::Mutate_(Halide::Internal::Mul const*, Halide::Expr const&) + 305
[bt] (4) 4   libhcl.dylib                        0x000000011c3a09c8 std::__1::__function::__func<TVM::ir::$_32, std::__1::allocator<TVM::ir::$_32>, Halide::Expr (Halide::Internal::Mul const*, Halide::Expr const&, TVM::ir::IRMutator*)>::operator()(Halide::Internal::Mul const*&&, Halide::Expr const&, TVM::ir::IRMutator*&&) + 24
[bt] (5) 5   libhcl.dylib                        0x000000011c395dc1 std::__1::__function::__func<TVM::IRFunctor<Halide::Internal::Stmt (TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>& TVM::IRFunctor<Halide::Internal::Stmt (TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>::set_dispatch<Halide::Internal::AttrStmt>(std::__1::function<Halide::Internal::Stmt (Halide::Internal::AttrStmt const*, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>)::'lambda'(TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*), std::__1::allocator<TVM::IRFunctor<Halide::Internal::Stmt (TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>& TVM::IRFunctor<Halide::Internal::Stmt (TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>::set_dispatch<Halide::Internal::AttrStmt>(std::__1::function<Halide::Internal::Stmt (Halide::Internal::AttrStmt const*, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>)::'lambda'(TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>, Halide::Internal::Stmt (TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*)>::operator()(TVM::NodeRef const&, Halide::Internal::Stmt const&, TVM::ir::IRMutator*&&) + 49
[bt] (6) 6   libhcl.dylib                        0x000000011c22060c TVM::IRFunctor<Halide::Expr (TVM::NodeRef const&, Halide::Expr const&, TVM::ir::IRMutator*)>::operator()(TVM::NodeRef const&, Halide::Expr const&, TVM::ir::IRMutator*) const + 348
[bt] (7) 7   libhcl.dylib                        0x000000011c318abb TVM::ir::IRMutator::Mutate(Halide::Expr) + 59
[bt] (8) 8   libhcl.dylib                        0x000000011c38c91f TVM::ir::IRMutator::Mutate_(Halide::Internal::Cast const*, Halide::Expr const&) + 63
[bt] (9) 9   libhcl.dylib                        0x000000011c3a5188 std::__1::__function::__func<TVM::ir::$_46, std::__1::allocator<TVM::ir::$_46>, Halide::Expr (Halide::Internal::Cast const*, Halide::Expr const&, TVM::ir::IRMutator*)>::operator()(Halide::Internal::Cast const*&&, Halide::Expr const&, TVM::ir::IRMutator*&&) + 24

zzzDavid commented 3 years ago

I added a corresponding test case; this issue will be fixed by pull request #375.
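
For readers unfamiliar with the failed check in the error stack above (`a.type() == b.type()` in `Mul::make`), here is a toy model in plain Python of how dropping a cast can trip that kind of check when an expression tree is rebuilt. It is only a sketch: `Var`, `Cast`, `Mul`, `make_mul`, and `strip_casts` are hypothetical names, the operand types are assumed for illustration, and none of it is HeteroCL or HalideIR code or a diagnosis of the actual bug.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str
    dtype: str  # e.g. "int32" or "uint16" (illustrative type names)


@dataclass(frozen=True)
class Cast:
    dtype: str     # target type of the cast
    value: object  # expression being cast


@dataclass(frozen=True)
class Mul:
    a: object
    b: object

    @property
    def dtype(self):
        # Both operands share one type, so either side gives the result type.
        return self.a.dtype


def make_mul(a, b):
    """Build a Mul node, rejecting operands whose types differ."""
    if a.dtype != b.dtype:
        raise TypeError(f"BinaryOp of mismatched types: {a.dtype} vs {b.dtype}")
    return Mul(a, b)


def strip_casts(expr):
    """Deliberately buggy toy pass: drops every Cast and rebuilds the tree."""
    if isinstance(expr, Cast):
        return strip_casts(expr.value)
    if isinstance(expr, Mul):
        return make_mul(strip_casts(expr.a), strip_casts(expr.b))
    return expr


# Assumed types, chosen only to make the mismatch visible.
i = Var("i", "uint16")
factor = Var("factor", "int32")

# With the cast in place, both operands are uint16 and construction succeeds.
ok = make_mul(i, Cast("uint16", factor))

# Rebuilding the product without the cast re-runs the type check and fails.
try:
    strip_casts(ok)
    failed = False
    message = ""
except TypeError as err:
    failed = True
    message = str(err)
```

In this toy, the original expression is well-typed only because of the explicit cast, so any rewrite that removes it has to fail as soon as the multiplication node is reconstructed.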