halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org
Other
5.86k stars 1.07k forks source link

Adams2019 tries to tile RVars that it's not allowed to #8278

Status: Open — jansel opened this issue 3 months ago

jansel commented 3 months ago

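Repro (scheduling this generator with the Adams2019 autoscheduler fails on the `out_ptr0` update definition; the error is shown in the output below):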

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    """Minimal Halide generator reproducing the Adams2019 RVar-reorder failure.

    Registered under the generator name "kernel". It reads a 1-D float32
    input through a permuted gather index, applies ``max(x, 0)``, and
    scatters the result into a 1-D float32 output via an update definition
    over a 2-D RDom. That update, ``out_ptr0.update(0)``, is the stage for
    which the Adams2019 schedule is rejected with "can't reorder RVars".
    """

    # One-dimensional float32 input and output buffers.
    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        """Define the pipeline and set bound estimates for the autoscheduler.

        ``g`` plays the role of ``self``: it is the generator instance, and
        the input/output buffers are read from it.
        """
        in_ptr0 = g.in_ptr0
        out_ptr0 = g.out_ptr0
        # Pure dimensions of the intermediate Funcs.
        xindex = hl.Var('xindex')
        yindex = hl.Var('yindex')
        x2 = xindex
        # yindex is decomposed into y0 = yindex % 192 and y1 = yindex // 192.
        y0 = yindex % 192
        y1 = (yindex // 192)
        # 2-D reduction domain: odom.x over [0, 169), odom.y over [0, 3072)
        # (hl.Range takes min and extent).
        odom = hl.RDom([hl.Range(0, 169), hl.Range(0, 3072)])
        xindex_odom = odom.x
        yindex_odom = odom.y
        # Same decomposition as above, applied to the RVars.
        x2_odom = xindex_odom
        y0_odom = yindex_odom % 192
        y1_odom = (yindex_odom // 192)
        # Gather from the flat input; 32448 == 169 * 192, so in_ptr0 is read
        # as if laid out [y1][x2][y0] with inner sizes 169 x 192.
        tmp0 = hl.Func('tmp0')
        tmp0[xindex, yindex] = in_ptr0[y0 + (192*x2) + (32448*y1)]
        # Elementwise max(tmp0, 0).
        tmp1 = hl.Func('tmp1')
        tmp1[xindex, yindex] = hl.max(tmp0[xindex, yindex], 0)
        # The pure definition of the output is left undefined; only the
        # update below writes values.
        out_ptr0[hl.Var()] = hl.undef(out_ptr0.type())
        # Scatter over the RDom (this is out_ptr0.update(0), the stage whose
        # Adams2019 schedule fails). 64896 == 2 * 32448.
        # NOTE(review): the largest index written is
        # 168 + 169*191 + 64896*15 = 1005887, beyond the 519168 estimate set
        # below; presumably out_ptr0 is a strided view of a larger buffer —
        # confirm.
        out_ptr0[x2_odom + (169*y0_odom) + (64896*y1_odom)] = hl.cast(out_ptr0.type(), tmp1[xindex_odom, yindex_odom])

        # The repro expects to be scheduled by an autoscheduler, never
        # manually.
        assert g.using_autoscheduler()
        # Both buffers are estimated as 519168 == 16 * 32448 elements.
        in_ptr0.set_estimates([hl.Range(0, 519168)])
        out_ptr0.set_estimates([hl.Range(0, 519168)])

if __name__ == "__main__":
    # Drive the "kernel" generator through Halide's command-line entry point
    # (hl.main reads sys.argv), writing outputs to a throwaway directory and
    # scheduling with the Adams2019 autoscheduler.
    import os
    import sys
    import tempfile

    # Path to the Adams2019 autoscheduler plugin. The default is specific to
    # one machine, so it can be overridden with HALIDE_AUTOSCHEDULER_PLUGIN
    # to run the repro elsewhere.
    plugin = os.environ.get(
        "HALIDE_AUTOSCHEDULER_PLUGIN",
        '/home/jansel/conda/envs/pytorch/lib/libautoschedule_adams2019.so',
    )
    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            'repro.py',
            '-g', 'kernel',
            '-o', out,
            '-f', 'halide_kernel',
            '-e', 'static_library,h,schedule',
            '-p', plugin,
            'target=host-strict_float-no_runtime-no_asserts',
            'autoscheduler=Adams2019', 'autoscheduler.parallelism=8',
        ]
        hl.main()

Output

Unhandled exception: Error: In schedule for out_ptr0.update(0), can't reorder RVars r6$yi and r6$x because it may change the meaning of the algorithm.

Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 45, in <module>
    hl.main()
RuntimeError: Generator failed: -1