apache / tvm

Open deep learning compiler stack for CPU, GPU and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Arith][SVE] Add rewrite rules for indices split by scalable expressions #17046

Closed: Anndrey24 closed this pull request 1 month ago.

Anndrey24 commented 1 month ago

This commit introduces rewrite rules for indices that can arise from splitting axes by scalable factors (e.g. xo, xi = sch.split(x, factors=[None, 8 * T.vscale()])):

(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) // (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_o
(v_x_o * T.Cast("int64", T.vscale()) * T.int64(8) + v_x_i) % (T.Cast("int64", T.vscale()) * T.int64(8)) == v_x_i
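Both identities follow from the bounds a split guarantees. With c = 8 * vscale > 0 and 0 <= v_x_i < c, the recomposed index satisfies v_x_o * c <= v_x_o * c + v_x_i < (v_x_o + 1) * c, so the floor quotient is v_x_o and the remainder is v_x_i. The sketch below checks this numerically for a few concrete vscale values. It assumes that // and % here denote floor division and floor modulo, which match Python's integer // and %. The helper name check_split_identities is hypothetical. Inside TVM, vscale is a symbolic runtime value rather than a fixed integer.

```python
def check_split_identities(vscale: int, extent_o: int) -> bool:
    """Check (x_o*c + x_i) // c == x_o and (x_o*c + x_i) % c == x_i.

    Hypothetical helper: c = 8 * vscale, with vscale fixed to a concrete
    integer here (in TVM it is the symbolic T.vscale()).
    """
    c = 8 * vscale  # scalable split factor; vscale >= 1, so c > 0
    for x_o in range(extent_o):      # outer loop variable of the split
        for x_i in range(c):         # the split guarantees 0 <= x_i < c
            x = x_o * c + x_i        # recomposed (fused) index
            if x // c != x_o or x % c != x_i:
                return False
    return True


for vscale in (1, 2, 4, 16):
    print(vscale, check_split_identities(vscale, extent_o=5))
```

The bound on v_x_i is essential. If x_i could reach c, then for x_o = 0 and c = 8 the quotient (0 * 8 + 8) // 8 is 1, not 0. This is why the analyzer must know the range of v_x_i before it can apply either rewrite.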

These rewrites help the arithmetic analyzer prove the checks needed by sch.tensorize() (e.g. in CompareBufferRegion).

cc @ekalda @lhutton1

lhutton1 commented 1 month ago

cc @Lunderberg

ekalda commented 1 month ago

We decided to roll back to the initial version of this patch, since a large number of Relax tests fail when division by zero is disabled in the rewrite rules; dealing with that is out of scope for this patch.

ekalda commented 1 month ago

Thanks @Anndrey24!