halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org
Other
5.78k stars 1.07k forks source link

Adams2019 autoscheduler fails with "Ran out of legal states with beam size 32" #8308

Open jansel opened 1 week ago

jansel commented 1 week ago

Repro:

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    """Minimal Halide generator reproducing the Adams2019 autoscheduler failure.

    The pipeline gathers rows of ``in_ptr1`` using sign-adjusted, clamped
    indices read from ``in_ptr0``; autoscheduling this graph with Adams2019
    aborts with "Ran out of legal states with beam size 32".
    """

    # Pipeline inputs/outputs (all 1-D).
    in_ptr0 = hl.InputBuffer(hl.Int(64), 1)
    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
    ks0 = hl.InputScalar(hl.Int(64))
    ks1 = hl.InputScalar(hl.Int(64))
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        """Define the Func graph and set the estimates the autoscheduler needs."""
        in_ptr0 = g.in_ptr0
        in_ptr1 = g.in_ptr1
        ks0 = g.ks0
        # NOTE(review): ks1 is bound here but never used in the pipeline body;
        # only its estimate is set below.
        ks1 = g.ks1
        out_ptr0 = g.out_ptr0
        xindex = hl.Var('xindex')
        # Decompose the flat index into (row, column) for a 64-wide layout.
        x1 = (xindex // 64)
        x0 = xindex % 64
        x2 = xindex
        tmp0 = hl.Func('tmp0')
        tmp0[xindex] = in_ptr0[x1]  # raw (possibly negative) gather index
        tmp1 = hl.Func('tmp1')
        # NOTE(review): the Func created on the line above is immediately
        # shadowed by rebinding tmp1 to the scalar ks0 — the hl.Func('tmp1')
        # object is never used. Kept as-is; this is machine-generated repro code.
        tmp1 = ks0
        tmp2 = hl.Func('tmp2')
        tmp2[xindex] = tmp0[xindex] + tmp1  # index shifted by ks0
        tmp3 = hl.Func('tmp3')
        tmp3[xindex] = tmp0[xindex] < 0
        tmp4 = hl.Func('tmp4')
        # Wrap negative indices: tmp0 + ks0 when tmp0 < 0, otherwise tmp0.
        tmp4[xindex] = hl.select(tmp3[xindex], tmp2[xindex], hl.cast(tmp2.type(), tmp0[xindex]))
        tmp5 = hl.Func('tmp5')
        tmp5[xindex] = hl.cast(hl.Int(32), tmp4[xindex])
        tmp6 = hl.Func('tmp6')
        # Clamp the index into [0, ks0 - 1].
        tmp6[xindex] = hl.clamp(tmp5[xindex], 0, hl.cast(tmp5.type(), (-1) + ks0))
        tmp7 = hl.Func('tmp7')
        # Gather from in_ptr1, reading 0 for any out-of-bounds access.
        tmp7[xindex] = hl.BoundaryConditions.constant_exterior(in_ptr1, 0)[x0 + (64*(tmp6[xindex]))]
        out_ptr0[x2] = hl.cast(out_ptr0.type(), tmp7[xindex])

        # Estimates are required when running under an autoscheduler.
        assert g.using_autoscheduler()
        in_ptr0.set_estimates([hl.Range(0, 1)])
        in_ptr1.set_estimates([hl.Range(0, 2048)])
        ks0.set_estimate(3)
        ks1.set_estimate(1)
        out_ptr0.set_estimates([hl.Range(0, 1)])

if __name__ == "__main__":
    # Script mode: drive Halide's generator main() with a synthetic argv to
    # emit a static library, header, and schedule into a temporary directory,
    # autoscheduling with Adams2019. The .so path is machine-specific.
    import sys, tempfile
    with tempfile.TemporaryDirectory() as out:
        sys.argv = ['repro.py', '-g', 'kernel', '-o', out, '-f', 'halide_kernel', '-e', 'static_library,h,schedule', '-p', '/home/jansel/conda/envs/pytorch/lib/libautoschedule_adams2019.so', 'target=host-strict_float-no_runtime-no_asserts-large_buffers', 'autoscheduler=Adams2019', 'autoscheduler.parallelism=8']
        hl.main()

else:
    # Import mode: build and autoschedule the pipeline programmatically, then
    # compile it to a Python callable using the same target/plugin settings.
    hl.load_plugin('/home/jansel/conda/envs/pytorch/lib/libautoschedule_adams2019.so')
    target = hl.Target('host-strict_float-no_runtime-no_asserts-large_buffers')
    autoscheduler = hl.AutoschedulerParams('Adams2019', {'parallelism': 8})
    with hl.GeneratorContext(target, autoscheduler):
        gen = Kernel()
        pipeline = gen._build_pipeline()
        # gen.compile_to_callable() does not run the autoscheduler
        pipeline.apply_autoscheduler(target, autoscheduler)
        # Collect only the input arguments (outputs are produced by the call).
        kernel = pipeline.compile_to_callable([
                gen._get_input_parameter(a.name)._to_argument()
                for a in gen._get_arginfos()
                if a.dir == hl.ArgInfoDirection.Input
            ], target)

Output:

State with cost 3.82387e-06:

realize: out_ptr0
realize: tmp7
realize: tmp6
realize: tmp5
realize: tmp4
realize: tmp2
realize: tmp3
realize: tmp0
tmp0 1 (0, 0) t p
 tmp0 1 (0, 0) t
  tmp0 16vc (0, 0) *
tmp3 1 (0, 0) t p
 tmp3 1 (0, 0) t
  tmp3 16vc (0, 0) *
tmp2 1 (0, 0) t p
 tmp2 1 (0, 0) t
  tmp2 16vc (0, 0) *
tmp4 1 (0, 0) t p
 tmp4 1 (0, 0) t
  tmp4 16vc (0, 0) *
tmp5 1 (0, 0) t p
 tmp5 1 (0, 0) t
  tmp5 16vc (0, 0) *
tmp6 1 (0, 0) t p
 tmp6 1 (0, 0) t
  tmp6 16vc (0, 0) *
tmp7 1 (0, 0) t p
 tmp7 1 (0, 0) t
  tmp7 16vc (0, 0) *
out_ptr0 1 (0, 0) t p
 out_ptr0 1 (0, 0) t
  out_ptr0 16vc (0, 0) *
Unhandled exception: Internal Error at /home/jansel/Halide/src/autoschedulers/adams2019/AutoSchedule.cpp:295 triggered by user code at : Ran out of legal states with beam size 32

Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 71, in <module>
    hl.main()
RuntimeError: Generator failed: -1