apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Crash if we use the same variable in LetStmt and Select Expr #9407

Status: Open, opened by syang-ng 2 years ago

syang-ng commented 2 years ago

It seems that using the same variable v both as the LetStmt binding and inside the Select expression bound to it, as below, triggers a crash:

primfn(v: int32) -> () {
  let v = select(cast(bool, v), 150, 1)
  v
}

Expected behavior

Either raise an exception or build successfully.

Actual behavior

The program crashes. Tracing the bug with gdb, it appears that op is a null pointer in arith/ir_mutator_with_analyzer.cc.

Environment

Ubuntu 18.04, CMake 3.18.2, LLVM 12; tested on the latest git commit 4087e72.

Steps to reproduce

Here is the code to reproduce it:

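import tvm
from tvm import tir

v = tir.Var('v', 'int32')
# The Select reads v, and the LetStmt below rebinds that same v to the Select.
s = tir.Select(tir.Cast('bool', v), 150, 1)
let_stmt = tir.LetStmt(v, s, tir.Evaluate(v))
# v is also the function parameter, so it is bound twice (not SSA form).
f = tir.PrimFunc([v], let_stmt)
tvm.build(f)  # crashes instead of raising an error
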
syang-ng commented 2 years ago

Similarly, the following example can also trigger a crash.

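import tvm
from tvm import tir

var = tir.Var('var', 'int32')
# false_value = cast(int32, acosh(cast(float32, var)))
false_value = tir.Cast('int32', tir.acosh(tir.Cast('float32', var)))
value = tir.Select(var > 0, tir.const(0), false_value)
# var is never bound before this LetStmt, yet its own value reads it.
let_stmt = tir.LetStmt(var, value, tir.Evaluate(var))
f = tir.PrimFunc([], let_stmt)
tvm.build(f)

The resulting TIR, as printed by print(f):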
primfn() -> () {
  let var: int32 = select((var > 0), 0, cast(int32, @tir.acosh(cast(float32, var), dtype=float32)))
  var
}

Unlike the previous example, this one crashes in PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op).

wrongtest-intellif commented 2 years ago

Interesting findings! I wonder why this does not lead to an infinite recursion error... In general, I think TIR that is not in SSA form is not safe to build.

syang-ng commented 2 years ago

Thanks for your reply! In theory it should be an infinite recursion error, but the backtrace says otherwise. I suspect that another error is triggered during the recursion.
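Why infinite recursion is the expected failure can be illustrated with a deliberately naive, hypothetical simplifier in pure Python. The tuple IR and the functions simplify, simplify_let and recurses_forever are illustrative assumptions, not TVM's IRMutatorWithAnalyzer. If a let-bound name is inlined by looking up its value, and that value mentions the same name, the lookup never terminates; with a fresh name the chain stops at the parameter v. This sketch does not explain why the real backtrace shows a null op instead.

```python
# Hypothetical toy, NOT TVM code: a naive simplifier that inlines let-bound
# variables by looking them up in a binding table and simplifying the value.
# Expressions: ("var", name) | ("const", value) | ("call", op, [args]).


def simplify(expr, bindings):
    kind = expr[0]
    if kind == "var":
        name = expr[1]
        if name in bindings:
            # Inline the bound value and keep simplifying it.
            return simplify(bindings[name], bindings)
        return expr
    if kind == "const":
        return expr
    if kind == "call":
        return ("call", expr[1], [simplify(a, bindings) for a in expr[2]])
    raise ValueError(f"unknown node {kind!r}")


def simplify_let(name, value, body):
    """Simplify `let name = value in body` by binding name -> value."""
    return simplify(body, {name: value})


def recurses_forever(name, value, body):
    """True if the naive inlining above never terminates (RecursionError)."""
    try:
        simplify_let(name, value, body)
    except RecursionError:
        return True
    return False


# let v = select(cast(bool, v), 150, 1) in v   (v refers to itself)
self_ref = ("call", "select", [("call", "cast_bool", [("var", "v")]),
                               ("const", 150), ("const", 1)])
print(recurses_forever("v", self_ref, ("var", "v")))      # True

# SSA form: bind a fresh name, so the lookup chain ends at the parameter v.
print(recurses_forever("v_1", self_ref, ("var", "v_1")))  # False
```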

syang-ng commented 2 years ago

I think the following crash has the same cause. Here is example code that triggers it:

import tvm

var = tvm.tir.Var('var', dtype='int32')
loop_var = tvm.tir.Var('loop_var', dtype='int32')
# The loop's min_val is computed from loop_var itself.
min_val = tvm.tir.floormod(loop_var + 1, tvm.tir.const(3))
none = tvm.tir.Evaluate(tvm.tir.const(0))  # no-op statement used as a body

seq = tvm.tir.SeqStmt([
    # var is a function parameter and is rebound here.
    tvm.tir.LetStmt(var, tvm.tir.const(1), body=none),
    tvm.tir.For(loop_var=loop_var, min_val=min_val, extent=var, kind=tvm.tir.ForKind.SERIAL, body=none)
])

func = tvm.tir.PrimFunc(params=[var], body=seq)
print(func)
tvm.build(func)
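The SSA point raised in this thread can be made concrete with a minimal pure-Python sketch over a hypothetical toy IR. The tuple encoding, SSAError, check_ssa and ssa_violation are illustrative assumptions, not TVM APIs; for real PrimFuncs TVM provides tvm.tir.analysis.verify_ssa, whose exact checks are not reproduced here. The sketch encodes the three examples above and reports the first problem it finds: a name bound more than once, or a name read where it is not bound. Raising an error like this before lowering is the "raise an exception" behavior the report expects.

```python
# Hypothetical toy IR for illustration only, NOT TVM's data structures:
#   ("var", name), ("const", value), ("call", op, [args])   expressions
#   ("let", name, value, body)                              LetStmt
#   ("for", name, min, extent, body)                        For loop
#   ("seq", [stmts]), ("eval", expr)                        SeqStmt, Evaluate


class SSAError(Exception):
    """Raised (instead of crashing) when the toy IR is not in SSA form."""


def check_ssa(params, body):
    """Raise SSAError if a variable is bound twice or used while unbound."""
    defined = set(params)  # every name bound so far anywhere in the function
    if len(defined) != len(params):
        raise SSAError("duplicate parameter")

    def bind(name):
        if name in defined:
            raise SSAError(f"'{name}' is bound more than once")
        defined.add(name)

    def visit(node, scope):
        kind = node[0]
        if kind == "var":
            if node[1] not in scope:
                raise SSAError(f"'{node[1]}' is used outside its binding")
        elif kind == "call":
            for arg in node[2]:
                visit(arg, scope)
        elif kind == "let":
            _, name, value, inner = node
            visit(value, scope)  # the value is evaluated before name is bound
            bind(name)
            visit(inner, scope | {name})
        elif kind == "for":
            _, name, lo, extent, inner = node
            visit(lo, scope)  # loop bounds are evaluated outside the loop
            visit(extent, scope)
            bind(name)
            visit(inner, scope | {name})
        elif kind == "seq":
            for stmt in node[1]:
                visit(stmt, scope)
        elif kind == "eval":
            visit(node[1], scope)
        elif kind != "const":
            raise ValueError(f"unknown node {kind!r}")

    visit(body, frozenset(params))


def ssa_violation(params, body):
    """Return the SSAError message, or None if the toy IR is in SSA form."""
    try:
        check_ssa(params, body)
    except SSAError as err:
        return str(err)
    return None


V, VAR, LOOP = ("var", "v"), ("var", "var"), ("var", "loop_var")
select1 = ("call", "select", [("call", "cast_bool", [V]), ("const", 150), ("const", 1)])

# First example: let v = select(cast(bool, v), 150, 1) with v also a parameter
example1 = ("let", "v", select1, ("eval", V))
# Second example: let var = select(var > 0, 0, ...) with no parameters
acosh = ("call", "cast_int32", [("call", "acosh", [("call", "cast_float32", [VAR])])])
example2 = ("let", "var", ("call", "select", [("call", ">", [VAR, ("const", 0)]),
                                               ("const", 0), acosh]), ("eval", VAR))
# Third example: the For loop's min_val refers to its own loop_var
loop = ("for", "loop_var", ("call", "floormod", [("call", "+", [LOOP, ("const", 1)]),
                                                 ("const", 3)]), VAR, ("eval", ("const", 0)))
example3 = ("seq", [("let", "var", ("const", 1), ("eval", ("const", 0))), loop])
# SSA rewrite of the first example: bind a fresh name instead of reusing v
fixed1 = ("let", "v_1", select1, ("eval", ("var", "v_1")))

print(ssa_violation(["v"], example1))    # 'v' is bound more than once
print(ssa_violation([], example2))       # 'var' is used outside its binding
print(ssa_violation(["var"], example3))  # 'var' is bound more than once
print(ssa_violation(["var"], loop))      # 'loop_var' is used outside its binding
print(ssa_violation(["v"], fixed1))      # None
```

The third example is reported for rebinding var before the walk reaches the loop, which is why the loop is also checked on its own.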