halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org
Other
5.77k stars 1.07k forks source link

Error: Atomic predicated store is not supported from Li2018 autoscheduler #8280

Closed jansel closed 1 week ago

jansel commented 3 weeks ago

This fails only with the Li2018 autoscheduler:

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    # Halide generator that reproduces "Error: Atomic predicated store is not
    # supported." According to the report, the error only occurs when the
    # pipeline is scheduled by the Li2018 autoscheduler. The pipeline
    # (tmpN / in_ptrN / out_ptrN naming) is kept verbatim so the failure
    # still reproduces; do not restructure it.
    #
    # Inputs: two 1-D Int(64) buffers (size estimates 473 and 2869).
    # Outputs: a 1-D Bool buffer, a 1-D Int(64) buffer, and three 1-D Int(64)
    # buffers whose estimates have extent 1 and that receive reductions.
    # (There is no out_ptr2.)
    in_ptr0 = hl.InputBuffer(hl.Int(64), 1)
    in_ptr1 = hl.InputBuffer(hl.Int(64), 1)
    out_ptr0 = hl.OutputBuffer(hl.Bool(), 1)
    out_ptr1 = hl.OutputBuffer(hl.Int(64), 1)
    out_ptr3 = hl.OutputBuffer(hl.Int(64), 1)
    out_ptr4 = hl.OutputBuffer(hl.Int(64), 1)
    out_ptr5 = hl.OutputBuffer(hl.Int(64), 1)

    def generate(g):
        """Define the pipeline algorithm and the size estimates.

        No schedule directives are given. Scheduling is left entirely to the
        autoscheduler selected on the command line, and the assert below
        requires that one is in use.

        NOTE(review): presumably the failing atomic store arises when the
        autoscheduler parallelizes one of the sum reductions (tmp12/tmp14).
        The maintainer points to tracking issue #4298. This cannot be
        confirmed from this source alone.
        """
        in_ptr0 = g.in_ptr0
        in_ptr1 = g.in_ptr1
        out_ptr0 = g.out_ptr0
        out_ptr1 = g.out_ptr1
        out_ptr3 = g.out_ptr3
        out_ptr4 = g.out_ptr4
        out_ptr5 = g.out_ptr5
        rindex = hl.Var('rindex')
        # r0 is just an alias for rindex, used when indexing the buffers.
        r0 = rindex
        # Reduction domain over [0, 473), matching in_ptr0's size estimate.
        rdom = hl.RDom([hl.Range(0, 473)])
        tmp0 = hl.Func('tmp0')
        tmp0[rindex] = in_ptr0[r0]
        # Int(64) zero constant.
        tmp1 = hl.cast(hl.Int(64), 0)
        tmp2 = hl.Func('tmp2')
        # Mask: true where the input element is nonzero.
        tmp2[rindex] = tmp0[rindex] != tmp1
        out_ptr0[r0] = hl.cast(out_ptr0.type(), tmp2[rindex])
        # Python evaluates this to the plain bool False.
        tmp3 = (False != 0)
        tmp4 = hl.Func('tmp4')
        # True where the mask is false, i.e. where the input element is zero.
        tmp4[rindex] = tmp2[rindex] == tmp3
        # Negative indices are wrapped by adding 2869 (in_ptr1's estimated
        # extent), mirroring Python-style negative indexing.
        tmp5 = 2869
        tmp6 = hl.Func('tmp6')
        tmp6[rindex] = tmp0[rindex] + tmp5
        tmp7 = hl.Func('tmp7')
        tmp7[rindex] = tmp0[rindex] < 0
        tmp8 = hl.Func('tmp8')
        tmp8[rindex] = hl.select(tmp7[rindex], tmp6[rindex], tmp0[rindex])
        tmp9 = hl.Func('tmp9')
        tmp9[rindex] = hl.cast(hl.Int(32), tmp8[rindex])
        tmp10 = hl.Func('tmp10')
        # Gather from in_ptr1; the clamp keeps the index within [0, 2868].
        tmp10[rindex] = in_ptr1[hl.clamp(tmp9[rindex], 0, 2868)]
        tmp11 = hl.Func('tmp11')
        # Zero where the input element was zero, otherwise the gathered value.
        tmp11[rindex] = hl.select(tmp4[rindex], tmp1, tmp10[rindex])
        out_ptr1[r0] = hl.cast(out_ptr1.type(), tmp11[rindex])
        # Sum of tmp11 over the reduction domain.
        tmp12 = hl.sum(rdom, tmp11[rdom])
        # hl.Var() is a fresh anonymous Var; the stored value does not depend on it.
        out_ptr3[hl.Var()] = hl.cast(out_ptr3.type(), tmp12)
        tmp13 = hl.Func('tmp13')
        tmp13[rindex] = hl.cast(hl.Int(64), tmp2[rindex])
        # Count of nonzero input elements (sum of the mask cast to Int(64)).
        tmp14 = hl.sum(rdom, tmp13[rdom])
        out_ptr4[hl.Var()] = hl.cast(out_ptr4.type(), tmp14)
        # out_ptr5 receives the same reduction as out_ptr3.
        out_ptr5[hl.Var()] = hl.cast(out_ptr5.type(), tmp12)

        # This pipeline has no manual schedule, so it requires an autoscheduler.
        # (Note: assert is skipped when Python runs with -O.)
        assert g.using_autoscheduler()
        # Size estimates consumed by the autoscheduler.
        in_ptr0.set_estimates([hl.Range(0, 473)])
        in_ptr1.set_estimates([hl.Range(0, 2869)])
        out_ptr0.set_estimates([hl.Range(0, 473)])
        out_ptr1.set_estimates([hl.Range(0, 473)])
        out_ptr3.set_estimates([hl.Range(0, 1)])
        out_ptr4.set_estimates([hl.Range(0, 1)])
        out_ptr5.set_estimates([hl.Range(0, 1)])

if __name__ == "__main__":
    # Run the "kernel" generator through hl.main(), which is driven via
    # sys.argv; the generator flags are therefore injected there.
    import os
    import sys
    import tempfile

    # The Li2018 autoscheduler plugin location is machine-specific. Allow it
    # to be overridden via HALIDE_LI2018_PLUGIN so the repro can run outside
    # the reporter's environment. The default is the original path.
    plugin_path = os.environ.get(
        "HALIDE_LI2018_PLUGIN",
        "/home/jansel/conda/envs/pytorch/lib/libautoschedule_li2018.so",
    )
    # Artifacts go to a throwaway directory; only whether generation
    # succeeds matters for this repro.
    with tempfile.TemporaryDirectory() as out:
        sys.argv = ['repro.py',
                    '-g', 'kernel',
                    '-o', out,
                    '-f', 'halide_kernel',
                    '-e', 'static_library,h,schedule',
                    '-p', plugin_path,
                    'target=host-strict_float-no_runtime-no_asserts',
                    'autoscheduler=Li2018', 'autoscheduler.parallelism=8']
        hl.main()

Output:

Unhandled exception: Error: Atomic predicated store is not supported.

Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 76, in <module>
    hl.main()
RuntimeError: Generator failed: -1
abadams commented 3 weeks ago

Tracking issue for this was #4298