cornell-zhang / heterocl

HeteroCL: A Multi-Paradigm Programming Infrastructure for Software-Defined Heterogeneous Computing
https://cornell-zhang.github.io/heterocl/
Apache License 2.0

Eliminated variables cannot be found when using `reuse_at` #239

Status: Open. chhzh123 opened this issue 4 years ago.

chhzh123 commented 4 years ago

I'm trying to aggregate array values using `hcl.sum` and reuse some of the partial results. However, when the reduction axis has size 1, `reuse_at` causes a runtime error when `hcl.build` is called. See the following program:

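import heterocl as hcl

def test_reuse_compute():
    hcl.init()
    nz = 1  # extent of the rz reduction axis; the build only fails when nz == 1
    rx = hcl.reduce_axis(0, 3, name="rx")
    rz = hcl.reduce_axis(0, nz, name="rz")
    A = hcl.placeholder((nz, 10, 10), name="A")
    B = hcl.compute((10, 8), lambda y, x: hcl.sum(A[rz, y, x + rx], axis=[rz, rx]), "B")
    s = hcl.create_schedule([A, B])
    # Reuse A along the x axis of B (a sliding window over rx)
    RB = s.reuse_at(A, s[B], B.axis[1])
    print(hcl.lower(s))  # succeeds
    f = hcl.build(s)     # raises TVMError: 'rz' does not appeared in api_args

if __name__ == "__main__":
    test_reuse_compute()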
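The report does not state the expected result, so as a cross-check here is a NumPy reference for what `B` should compute: `B[y, x]` is the sum over `rz` in `[0, nz)` and `rx` in `[0, 3)` of `A[rz, y, x + rx]`. This is a sketch, not part of HeteroCL; `reference_B` is a hypothetical helper and NumPy is an assumption. It could be compared against the output of `f` in any configuration that does build.

```python
import numpy as np


def reference_B(A: np.ndarray) -> np.ndarray:
    """Hypothetical reference: B[y, x] = sum_{rz, rx} A[rz, y, x + rx], rx in [0, 3)."""
    nz, height, width = A.shape
    out_width = width - 2  # x + rx must stay in bounds for rx in [0, 3)
    B = np.zeros((height, out_width), dtype=A.dtype)
    for rx in range(3):
        # Reduce over rz first, then add the window slice shifted by rx.
        B += A[:, :, rx:rx + out_width].sum(axis=0, dtype=A.dtype)
    return B


# Usage with the shapes from the report (nz = 1, int32 as in the lowered IR).
A = np.arange(100, dtype=np.int32).reshape(1, 10, 10)
B = reference_B(A)
print(B.shape)  # (10, 8)
print(B[0, 0])  # 3, i.e. A[0,0,0] + A[0,0,1] + A[0,0,2] = 0 + 1 + 2
```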

The build reports that `rz` cannot be found; the variable seems to have been eliminated by an earlier schedule pass.

Traceback (most recent call last):
  File "reuse.py", line 121, in <module>
    test_reuse_compute_nd()
  File "reuse.py", line 86, in test_reuse_compute
    f = hcl.build(s)
  File "/home/chz/heterocl/python/heterocl/api.py", line 318, in build
    return _build(schedule.sch, new_inputs, target=target, name=name, stmt=stmt)
  File "/home/chz/heterocl/python/heterocl/tvm/build_module.py", line 568, in build
    stmt=stmt)
  File "/home/chz/heterocl/python/heterocl/tvm/build_module.py", line 385, in lower
    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
  File "/home/chz/heterocl/python/heterocl/tvm/_ffi/function.py", line 280, in my_api_func
    return flocal(*args)
  File "/home/chz/heterocl/python/heterocl/tvm/_ffi/_ctypes/function.py", line 183, in __call__
    ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
  File "/home/chz/heterocl/python/heterocl/tvm/_ffi/base.py", line 66, in check_call
    raise TVMError(py_str(_LIB.TVMGetLastError()))
heterocl.tvm._ffi.base.TVMError: [00:16:33] src/pass/make_api.cc:168: Not all Vars are passed in api_args:  'rz'  does not appeared in api_args
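Judging from the message, the MakeAPI pass (`src/pass/make_api.cc` in the traceback) rejects a function body that uses a variable which is neither bound inside the body nor passed as an argument. The toy Python model below illustrates that kind of free-variable check on the loop skeleton of the lowered IR. The classes `Loop` and `Use` and the functions `free_vars` and `check_api_args` are hypothetical illustrations, not HeteroCL's implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Loop:
    """A `for (var, 0, extent)` node; `var` is bound only inside `body`."""
    var: str
    extent: int
    body: list = field(default_factory=list)


@dataclass
class Use:
    """An index expression referencing the named loop variables."""
    names: tuple


def free_vars(stmts, bound=frozenset()):
    """Return the variables used in `stmts` that no enclosing Loop binds."""
    found = set()
    for stmt in stmts:
        if isinstance(stmt, Loop):
            found |= free_vars(stmt.body, bound | {stmt.var})
        elif isinstance(stmt, Use):
            found |= set(stmt.names) - bound
    return found


def check_api_args(body, api_args):
    """Toy analogue of the check: every free variable must be an argument."""
    missing = free_vars(body) - set(api_args)
    if missing:
        raise RuntimeError(f"free variables not passed as arguments: {sorted(missing)}")


# Loop skeleton of the lowered IR for nz == 1: `rz` appears in the load
# from A, but no `for (rz, ...)` loop encloses it.
lowered = [Loop("y", 10, [Loop("x.reuse", 10, [
    Use(("x.reuse", "y", "rz")),    # A[((x.reuse + (y*10)) + (rz*100))]
    Loop("rx", 3, [Use(("rx",))]),  # sum over A.reuse[rx]
])])]
print(free_vars(lowered))  # {'rz'}
```

Wrapping the skeleton in a `Loop("rz", 1, ...)` makes `free_vars` return an empty set, which is one way to see that the failure is about the dangling reference rather than the loop bounds.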

Note that `hcl.lower` succeeds and prints the following IR:

produce B {
  // attr [A.reuse] storage_scope = "global"
  allocate A.reuse[int32 * 1 * 1 * 3]
  // attr [0] extern_scope = 0
  for (y, 0, 10) {
    for (x.reuse, 0, 10) {
      produce A.reuse {
        for (A.0, 0, 2) {
          A.reuse[A.0] = A.reuse[(A.0 + 1)]
        }
        A.reuse[2] = A[((x.reuse + (y*10)) + (rz*100))]
      }
      if ((2 <= x.reuse)) {
        // attr [sum] storage_scope = "global"
        allocate sum[int32 * 1]
        produce sum {
          // attr [0] extern_scope = 0
          sum[0] = 0
        }
        for (rx, 0, 3) {
          sum[0] = int32((int33(A.reuse[rx]) + int33(sum[0])))
        }
        B[((x.reuse + (y*8)) + -2)] = sum[0]
      }
    }
  }
}
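In this IR, `rz` appears in the load `A[((x.reuse + (y*10)) + (rz*100))]`, but no enclosing loop binds it. The pure-Python transliteration below reads the rest of the loop nest as a 3-element shift register along `x`, with `A` and `B` as flat row-major buffers. It treats `rz` as a parameter defaulting to 0, its only valid value when `nz = 1`; that substitution is an assumption about what the intended code would be. `lowered_B` is a hypothetical name, and the `int33` casts are ignored because Python integers do not overflow.

```python
def lowered_B(A_flat, rz=0):
    """Pure-Python reading of the lowered IR for B with shape (10, 8).

    A_flat holds A in row-major order. rz is a parameter here because the IR
    leaves it unbound; with nz == 1 its only valid value is 0.
    """
    B_flat = [0] * (10 * 8)
    reuse = [0, 0, 0]  # A.reuse: allocated once, outside the y loop, as in the IR
    for y in range(10):
        for x_reuse in range(10):
            # produce A.reuse: shift the window left by one element...
            for a0 in range(2):
                reuse[a0] = reuse[a0 + 1]
            # ...then load the newest element of the current row.
            reuse[2] = A_flat[(x_reuse + y * 10) + rz * 100]
            if 2 <= x_reuse:
                # Three fresh values are now in the window, so stale entries
                # left over from the previous row have been shifted out.
                total = 0
                for rx in range(3):
                    total = reuse[rx] + total
                B_flat[(x_reuse + y * 8) - 2] = total
    return B_flat


A_flat = list(range(100))
B_flat = lowered_B(A_flat)
print(B_flat[0])  # 3 = A[0] + A[1] + A[2]
```

With `rz` fixed to 0, this loop nest matches the direct definition of `B` for every output element, which suggests the reuse buffer logic itself is fine and only the unbound `rz` prevents the build.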

Moreover, if I change `nz` to any value larger than 1, the program builds successfully.