Currently, my implementation consists of two stages:
```python
T_idx, T_val = COMPUTE_ARGMAX(data, name='T')
real_idx = CHOOSE_INDEX(T_idx, T_val)
```
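For reference, here is what the two stages compute, written as plain Python rather than TVM. This is a hypothetical sketch: `compute_argmax` and `choose_index` only mirror the placeholder names `COMPUTE_ARGMAX` and `CHOOSE_INDEX` above, and the row-wise reduction axis is an assumption.

```python
from typing import List, Tuple


def compute_argmax(data: List[List[float]]) -> Tuple[List[int], List[float]]:
    """Stage 1 (hypothetical): reduce each row to the (index, value) of its maximum."""
    t_idx: List[int] = []
    t_val: List[float] = []
    for row in data:
        best_i, best_v = 0, row[0]
        for i, v in enumerate(row):
            if v > best_v:  # strict '>' keeps the first index on ties
                best_i, best_v = i, v
        t_idx.append(best_i)
        t_val.append(best_v)
    return t_idx, t_val


def choose_index(t_idx: List[int], t_val: List[float]) -> List[int]:
    """Stage 2 (hypothetical): keep only the index half of each (index, value) pair."""
    return list(t_idx)


# Usage: row 0 peaks at index 1; row 1 ties at indices 0 and 2, so index 0 wins.
real_idx = choose_index(*compute_argmax([[1.0, 5.0, 3.0], [7.0, 2.0, 7.0]]))
```

The value output only exists so the reduction can compare candidates; the final result needs just the index.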
During scheduling, the first stage reuses the reduction scheduling code (https://github.com/dmlc/tvm/blob/master/topi/python/topi/cuda/reduction.py#L7-L42). In the second stage, T_idx and T_val will be attached with compute_at to some axis.
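Reusing a split reduction schedule for argmax only works if partial (index, value) results can be combined in any grouping. The plain-Python model below is a sketch under that assumption, not TVM's scheduling code: it splits the reduction axis into `num_parts` chunks (standing in for threads), reduces each chunk, then combines the partials. `argmax_combine` and `split_reduce_argmax` are hypothetical names.

```python
from functools import reduce
from typing import List, Tuple

Pair = Tuple[int, float]  # (index, value)


def argmax_combine(a: Pair, b: Pair) -> Pair:
    """Keep the larger value; on ties keep the smaller index so the
    result does not depend on how the axis was split."""
    if a[1] > b[1]:
        return a
    if b[1] > a[1]:
        return b
    return a if a[0] <= b[0] else b


def split_reduce_argmax(row: List[float], num_parts: int) -> Pair:
    """Two-level reduction: per-chunk partial argmax, then a final combine."""
    n = len(row)
    chunk = -(-n // num_parts)  # ceiling division, so every chunk is non-empty
    pairs = list(enumerate(row))
    partials = [reduce(argmax_combine, pairs[k:k + chunk]) for k in range(0, n, chunk)]
    return reduce(argmax_combine, partials)


# Usage: every split of the axis gives the same answer, (1, 9.0).
results = [split_reduce_argmax([3.0, 9.0, 1.0, 9.0, 4.0], p) for p in range(1, 6)]
```

The explicit tie-break on the smaller index is the design choice that makes the combiner order-independent; without it, two chunks holding equal maxima could yield different indices depending on the split.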
The problem is that I cannot compute_at T_idx and T_val at axes of the first stage, because they are the outputs of the first stage.
If I try T_idx.compute_at(FIRST_STAGE_OUT, FIRST_STAGE_AXIS), I get this error:
```
tvm._ffi.base.TVMError: [23:43:06] D:\HKUST\tvm\src\schedule\graph.cc:179: Check failed: !visited.count(s.get()) Find loop in compute_at attach group
```
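The check in this error walks compute_at attachments and fails when a stage is visited twice. The toy model below is a hypothetical sketch, not TVM's graph code: it assumes each stage has at most one attach target, and that T_idx and T_val belong to one multi-output stage, so attaching them to an axis of their own stage makes the stage its own parent.

```python
from typing import Dict, Optional


def find_attach_loop(attach: Dict[str, Optional[str]], start: str) -> bool:
    """Follow compute_at attach targets from `start`; True if a stage repeats.

    `attach` maps a (hypothetical) stage name to the stage it is attached to,
    or None for a root stage.
    """
    visited = set()
    stage: Optional[str] = start
    while stage is not None:
        if stage in visited:
            return True  # same condition as the failed !visited.count(...) check
        visited.add(stage)
        stage = attach.get(stage)
    return False


# Attaching the argmax stage to an axis of its own output stage: a loop.
self_attach = find_attach_loop({"argmax": "argmax"}, "argmax")
# Attaching it to the second stage instead: no loop.
to_second = find_attach_loop({"argmax": "choose_index", "choose_index": None}, "argmax")
```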
If I try T_idx.compute_at(SECOND_STAGE_OUT, FIRST_STAGE_AXIS), I get this error:
```
tvm._ffi.base.TVMError: [23:55:14] D:\HKUST\tvm\src\schedule\schedule_lang.cc:133: Check failed: found Cannot find the axis iter_var(ax0.ax1.fused.outer, ) in parent's leaf_iter_vars parent=stage(argmax, 000002248AB6E980)
```
Is there a way to compute_at T_idx and T_val at an axis of the first stage?
We can do the compute_at at the second stage (CHOOSE_INDEX) instead: do the spatial split on the final stage, and keep only the reduction-related split on COMPUTE_ARGMAX.
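The resulting loop nest can be pictured in plain Python. This is a hypothetical sketch, not generated TVM code: the final stage's spatial (row) axis is split into outer and inner loops by `factor`, and the first stage's reduction is computed at the inner loop, i.e. attached to the second stage. The reduction-related split that would remain on COMPUTE_ARGMAX is left out, so the inner reduction loop is unsplit here.

```python
from typing import List


def fused_schedule(data: List[List[float]], factor: int) -> List[int]:
    """Final stage with a spatial outer/inner split; stage 1 computed inside."""
    n = len(data)
    out = [0] * n
    for outer in range(0, n, factor):  # spatial split of the final stage: outer part
        for r in range(outer, min(outer + factor, n)):  # inner part
            # Stage 1 (argmax reduction) computed at this axis for row r only.
            best_i, best_v = 0, data[r][0]
            for i, v in enumerate(data[r]):
                if v > best_v:
                    best_i, best_v = i, v
            # Stage 2 (choose index) consumes the (index, value) pair immediately.
            out[r] = best_i
    return out


# Usage: the split factor changes the loop structure, not the result.
rows = [[1.0, 5.0, 3.0], [7.0, 2.0, 7.0], [0.0, -1.0, 4.0]]
answers = [fused_schedule(rows, f) for f in (1, 2, 3, 4)]
```

Because the attach target is now a loop of the consumer stage, there is no cycle: the argmax stage hangs off CHOOSE_INDEX rather than off itself.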
Thanks a lot! I've solved this problem!