halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org
Other
5.86k stars 1.07k forks source link

Anderson2021 autoscheduler fails with: Condition failed: at_or_inside_block() #8256

Open jansel opened 4 months ago

jansel commented 4 months ago

This code is a cleaned-up lowering of part of torch.argmax(torch.adaptive_avg_pool1d(...)).

repro.py

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    """Minimal repro for the Anderson2021 ``at_or_inside_block()`` failure.

    Per the accompanying issue text, this is a cleaned-up lowering of part of
    ``torch.argmax(torch.adaptive_avg_pool1d(...))``. It reads a 1-D Float(64)
    buffer and writes a single argmax index (cast to Int(64)) into every
    element of a 1-D output buffer.
    """

    in_ptr0 = hl.InputBuffer(hl.Float(64), 1)
    out_ptr0 = hl.OutputBuffer(hl.Int(64), 1)

    def generate(g):
        """Define the pipeline, then the estimates used by the autoscheduler.

        ``g`` is the generator instance (used here in place of ``self``).
        """
        in_ptr0 = g.in_ptr0
        out_ptr0 = g.out_ptr0
        # Flat index split into an inner position r0 (0 or 1) and an outer
        # group r1.
        rindex = hl.Var("rindex")
        r0 = rindex % 2
        r1 = rindex // 2
        # 32 reduction points. With r1 <= 15, the largest in_ptr0 index read
        # below is 1 + 3*15 + (3*1)//2 = 47, matching the 0..48 estimate.
        rdom = hl.RDom([hl.Range(0, 32)])
        # Window bounds: these look like adaptive-pool start/end indices
        # (unconfirmed). tmp5 masks the first tap of the window.
        tmp3 = hl.Func("tmp3")
        tmp3[rindex] = (3 * r0) // 2
        tmp4 = hl.Func("tmp4")
        tmp4[rindex] = 2 + ((3 * r0) // 2)
        tmp5 = hl.Func("tmp5")
        tmp5[rindex] = tmp3[rindex] < tmp4[rindex]
        # First tap. in_ptr0 is read through constant_exterior, so indices
        # outside the buffer yield 0.
        # NOTE(review): constant_exterior is invoked twice (here and for
        # tmp13), creating two boundary Funcs. This is kept as-is so the
        # repro's pipeline shape is unchanged.
        tmp7 = hl.Func("tmp7")
        tmp7[rindex] = hl.BoundaryConditions.constant_exterior(in_ptr0, 0)[
            (3 * r1) + ((3 * r0) // 2)
        ]
        tmp9 = hl.Func("tmp9")
        tmp9[rindex] = hl.select(tmp5[rindex], tmp7[rindex], hl.f64(0.0))
        # Second tap, one element further along, masked by tmp11.
        tmp10 = hl.Func("tmp10")
        tmp10[rindex] = 1 + ((3 * r0) // 2)
        tmp11 = hl.Func("tmp11")
        tmp11[rindex] = tmp10[rindex] < tmp4[rindex]
        tmp13 = hl.Func("tmp13")
        tmp13[rindex] = hl.BoundaryConditions.constant_exterior(in_ptr0, 0)[
            1 + (3 * r1) + ((3 * r0) // 2)
        ]
        tmp15 = hl.Func("tmp15")
        tmp15[rindex] = hl.select(tmp11[rindex], tmp13[rindex], hl.f64(0.0))
        # Sum of the masked taps.
        tmp16 = hl.Func("tmp16")
        tmp16[rindex] = tmp15[rindex] + tmp9[rindex]
        # Number of valid taps: each mask contributes 1.0.
        tmp19 = hl.Func("tmp19")
        tmp19[rindex] = hl.select(tmp5[rindex], hl.f64(1.0), hl.f64(0.0))
        tmp20 = hl.Func("tmp20")
        tmp20[rindex] = hl.select(tmp11[rindex], hl.f64(1.0), hl.f64(0.0))
        tmp21 = hl.Func("tmp21")
        tmp21[rindex] = tmp20[rindex] + tmp19[rindex]
        # Window average.
        tmp22 = hl.Func("tmp22")
        tmp22[rindex] = tmp16[rindex] / tmp21[rindex]
        # Position of the largest average over the 32 reduction points.
        # Element 0 of hl.argmax's result tuple is the reduction coordinate.
        tmp23 = hl.argmax(rdom, tmp22[rdom])[0]
        # The right-hand side does not depend on the output index, so every
        # output element receives the same value.
        out_ptr0[hl.Var()] = hl.cast(out_ptr0.type(), tmp23)

        # Everything below only feeds the autoscheduler.
        assert g.using_autoscheduler()
        in_ptr0.set_estimates([hl.Range(0, 48)])
        # The real output extent is 1. It is declared as 2 to work around
        # https://github.com/halide/Halide/issues/8246.
        out_ptr0.set_estimates([hl.Range(0, 2)])

if __name__ == "__main__":
    # Run the generator through Halide's command-line driver. hl.main() reads
    # the arguments assigned to sys.argv below, and its outputs (static
    # library, header, schedule) go into a throwaway temporary directory.
    # The Anderson2021 plugin path defaults to the original author's local
    # build; set ANDERSON2021_PLUGIN_PATH to point at your own copy instead
    # of editing this file.
    import os
    import sys
    import tempfile

    plugin_path = os.environ.get(
        "ANDERSON2021_PLUGIN_PATH",
        "/home/jansel/conda/envs/pytorch/lib/libautoschedule_anderson2021.so",
    )
    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            "repro.py",
            "-g", "kernel",
            "-o", out,
            "-f", "halide_kernel",
            "-e", "static_library,h,schedule",
            "-p", plugin_path,
            "target=host-cuda-cuda_capability_86-user_context-strict_float-no_asserts",
            "autoscheduler=Anderson2021",
            "autoscheduler.parallelism=82",
        ]
        hl.main()

(You will need to update the path to libautoschedule_anderson2021.so.)

Output:

Unhandled exception: Internal Error at /home/jansel/Halide/src/autoschedulers/anderson2021/GPULoopInfo.cpp:92 triggered by user code at : Condition failed: at_or_inside_block(): 

Traceback (most recent call last):
  File "/home/jansel/pytorch/repro.py", line 72, in <module>
    hl.main()
RuntimeError: Generator failed: -1

The code includes a workaround for #8246: it declares the output size as 2 even though it is actually 1. If I remove that workaround, I get the same error as in #8246. I think the workaround is exposing a different issue, although the two issues may be related.

jansel commented 3 months ago

Here is another (larger) example of this same error:

import os
from math import inf, nan

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    """Second (larger) repro for the Anderson2021 ``at_or_inside_block()`` failure.

    Works on a fixed 5x5 ``in_ptr0`` and a length-5 ``in_ptr1``. Outputs:

    * ``out_ptr0[h0]``: NaN-propagating max of ``in_ptr0[0..4, h0]``.
    * ``out_ptr1[h0]``: ``log(sum_k exp(in_ptr0[k, h0] - max))``.
    * ``out_ptr4``: number of ``in_ptr1`` entries not equal to -100,
      written to every element.
    * ``out_ptr5``: sum of the masked ``-(x[target] - max - logsumexp)``
      terms divided by that count, written to every element.

    This appears to be a lowering of a cross-entropy / NLL loss with
    ``ignore_index=-100`` (unconfirmed).
    """

    in_ptr0 = hl.InputBuffer(hl.Float(32), 2)
    in_ptr1 = hl.InputBuffer(hl.Int(64), 1)
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)
    out_ptr1 = hl.OutputBuffer(hl.Float(32), 1)
    out_ptr4 = hl.OutputBuffer(hl.Float(32), 1)
    out_ptr5 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        """Define the pipeline, then pin buffer shapes and autoscheduler estimates.

        ``g`` is the generator instance (used here in place of ``self``).
        Names that hold plain scalar Exprs (tmp25, tmp28, tmp29, tmp31,
        tmp41, tmp43, tmp44, tmp45) are bound directly. Previously an unused
        ``hl.Func`` was allocated for each and immediately overwritten.
        """
        in_ptr0 = g.in_ptr0
        in_ptr1 = g.in_ptr1
        out_ptr0 = g.out_ptr0
        out_ptr1 = g.out_ptr1
        out_ptr4 = g.out_ptr4
        out_ptr5 = g.out_ptr5
        h0 = hl.Var("h0")
        rdom = hl.RDom([hl.Range(0, 5)])
        hr1 = rdom[0]
        # Running max over in_ptr0[0..4, h0]. Each `... if X.type().is_float()
        # else ...` is a Python-level choice made once at generation time:
        # floats use a select that propagates NaN, other types use hl.max.
        tmp0 = hl.Func("tmp0")
        tmp0[h0] = in_ptr0[
            0,
            h0,
        ]
        tmp1 = hl.Func("tmp1")
        tmp1[h0] = in_ptr0[
            1,
            h0,
        ]
        tmp2 = hl.Func("tmp2")
        tmp2[h0] = (
            hl.select((tmp0[h0] > tmp1[h0]) | hl.is_nan(tmp0[h0]), tmp0[h0], tmp1[h0])
            if tmp0.type().is_float()
            else hl.max(tmp0[h0], tmp1[h0])
        )
        tmp3 = hl.Func("tmp3")
        tmp3[h0] = in_ptr0[
            2,
            h0,
        ]
        tmp4 = hl.Func("tmp4")
        tmp4[h0] = (
            hl.select((tmp2[h0] > tmp3[h0]) | hl.is_nan(tmp2[h0]), tmp2[h0], tmp3[h0])
            if tmp2.type().is_float()
            else hl.max(tmp2[h0], tmp3[h0])
        )
        tmp5 = hl.Func("tmp5")
        tmp5[h0] = in_ptr0[
            3,
            h0,
        ]
        tmp6 = hl.Func("tmp6")
        tmp6[h0] = (
            hl.select((tmp4[h0] > tmp5[h0]) | hl.is_nan(tmp4[h0]), tmp4[h0], tmp5[h0])
            if tmp4.type().is_float()
            else hl.max(tmp4[h0], tmp5[h0])
        )
        tmp7 = hl.Func("tmp7")
        tmp7[h0] = in_ptr0[
            4,
            h0,
        ]
        tmp8 = hl.Func("tmp8")
        tmp8[h0] = (
            hl.select((tmp6[h0] > tmp7[h0]) | hl.is_nan(tmp6[h0]), tmp6[h0], tmp7[h0])
            if tmp6.type().is_float()
            else hl.max(tmp6[h0], tmp7[h0])
        )
        out_ptr0[h0,] = hl.cast(hl.Float(32), tmp8[h0])
        # Sum of exp(x - max) over the five entries. fast_exp is chosen for
        # types of 32 bits or fewer, exp otherwise.
        tmp9 = hl.Func("tmp9")
        tmp9[h0] = tmp0[h0] - tmp8[h0]
        tmp10 = hl.Func("tmp10")
        tmp10[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp9[h0]))
            if tmp9.type().bits() <= 32
            else hl.exp(tmp9[h0])
        )
        tmp11 = hl.Func("tmp11")
        tmp11[h0] = tmp1[h0] - tmp8[h0]
        tmp12 = hl.Func("tmp12")
        tmp12[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp11[h0]))
            if tmp11.type().bits() <= 32
            else hl.exp(tmp11[h0])
        )
        tmp13 = hl.Func("tmp13")
        tmp13[h0] = tmp10[h0] + tmp12[h0]
        tmp14 = hl.Func("tmp14")
        tmp14[h0] = tmp3[h0] - tmp8[h0]
        tmp15 = hl.Func("tmp15")
        tmp15[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp14[h0]))
            if tmp14.type().bits() <= 32
            else hl.exp(tmp14[h0])
        )
        tmp16 = hl.Func("tmp16")
        tmp16[h0] = tmp13[h0] + tmp15[h0]
        tmp17 = hl.Func("tmp17")
        tmp17[h0] = tmp5[h0] - tmp8[h0]
        tmp18 = hl.Func("tmp18")
        tmp18[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp17[h0]))
            if tmp17.type().bits() <= 32
            else hl.exp(tmp17[h0])
        )
        tmp19 = hl.Func("tmp19")
        tmp19[h0] = tmp16[h0] + tmp18[h0]
        tmp20 = hl.Func("tmp20")
        tmp20[h0] = tmp7[h0] - tmp8[h0]
        tmp21 = hl.Func("tmp21")
        tmp21[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp20[h0]))
            if tmp20.type().bits() <= 32
            else hl.exp(tmp20[h0])
        )
        tmp22 = hl.Func("tmp22")
        tmp22[h0] = tmp19[h0] + tmp21[h0]
        tmp23 = hl.Func("tmp23")
        tmp23[h0] = hl.log(tmp22[h0])
        out_ptr1[h0,] = hl.cast(hl.Float(32), tmp23[h0])
        # Targets from in_ptr1. Entries equal to -100 are masked out
        # (presumably PyTorch's default ignore_index; unconfirmed).
        tmp24 = hl.Func("tmp24")
        tmp24[h0] = in_ptr1[h0,]
        tmp25 = hl.cast(hl.Int(64), -100)
        tmp26 = hl.Func("tmp26")
        tmp26[h0] = tmp24[h0] != tmp25
        tmp27 = hl.Func("tmp27")
        tmp27[h0] = hl.cast(hl.Int(64), tmp26[h0])
        # Count of non-masked targets over the 5-element reduction domain.
        tmp28 = hl.sum(rdom, tmp27[hr1])
        tmp29 = hl.cast(hl.Int(64), 0)
        # Masked targets become 0. Negative targets wrap by +5, and the
        # result is clamped to the valid in_ptr0 row range 0..4.
        tmp30 = hl.Func("tmp30")
        tmp30[h0] = hl.select(tmp26[h0], tmp24[h0], hl.cast(tmp24.type(), tmp29))
        tmp31 = 5
        tmp32 = hl.Func("tmp32")
        tmp32[h0] = tmp30[h0] + tmp31
        tmp33 = hl.Func("tmp33")
        tmp33[h0] = tmp30[h0] < 0
        tmp34 = hl.Func("tmp34")
        tmp34[h0] = hl.select(tmp33[h0], tmp32[h0], hl.cast(tmp32.type(), tmp30[h0]))
        tmp35 = hl.Func("tmp35")
        tmp35[h0] = hl.cast(hl.Int(32), tmp34[h0])
        tmp36 = hl.Func("tmp36")
        tmp36[h0] = hl.clamp(tmp35[h0], 0, 4)
        # Data-dependent gather in_ptr0[target, h0], then
        # -(x[target] - max - log(sum exp(x - max))).
        tmp37 = hl.Func("tmp37")
        tmp37[h0] = in_ptr0[
            tmp36[h0],
            h0,
        ]
        tmp38 = hl.Func("tmp38")
        tmp38[h0] = tmp37[h0] - tmp8[h0]
        tmp39 = hl.Func("tmp39")
        tmp39[h0] = tmp38[h0] - tmp23[h0]
        tmp40 = hl.Func("tmp40")
        tmp40[h0] = -tmp39[h0]
        tmp41 = hl.cast(hl.Float(32), hl.f64(0.0))
        tmp42 = hl.Func("tmp42")
        tmp42[h0] = hl.select(tmp26[h0], tmp40[h0], hl.cast(tmp40.type(), tmp41))
        tmp43 = hl.sum(rdom, tmp42[hr1])
        tmp44 = hl.cast(hl.Float(32), tmp28)
        # Scalar results, written to every element of their output buffers.
        out_ptr4[hl.Var(),] = hl.cast(hl.Float(32), tmp44)
        tmp45 = tmp43 / tmp44
        out_ptr5[hl.Var(),] = hl.cast(hl.Float(32), tmp45)

        # Fixed shapes: in_ptr0 is 5x5 and dense (dim 0 stride 1, dim 1
        # stride 5); in_ptr1 has 5 elements. Estimates feed the autoscheduler.
        assert g.using_autoscheduler()
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(5)
        in_ptr0.dim(1).set_min(0)
        in_ptr0.dim(1).set_stride(5)
        in_ptr0.dim(1).set_extent(5)
        in_ptr0.set_estimates([hl.Range(0, 5), hl.Range(0, 5)])
        in_ptr1.dim(0).set_min(0)
        in_ptr1.dim(0).set_stride(1)
        in_ptr1.dim(0).set_extent(5)
        in_ptr1.set_estimates([hl.Range(0, 5)])
        out_ptr0.set_estimates([hl.Range(0, 5)])
        out_ptr1.set_estimates([hl.Range(0, 5)])
        # NOTE(review): the scalar outputs are estimated with extent 2.
        # Presumably this is the same issue-8246 workaround used in the first
        # repro; confirm.
        out_ptr4.set_estimates([hl.Range(0, 2)])
        out_ptr5.set_estimates([hl.Range(0, 2)])

# Settings shared by the command-line path and the import-time path below, so
# the two cannot drift apart. The autoscheduler plugin defaults to the original
# author's local build; set ANDERSON2021_PLUGIN_PATH to point elsewhere.
_PLUGIN_PATH = os.environ.get(
    "ANDERSON2021_PLUGIN_PATH",
    "/home/jansel/conda/envs/pytorch/lib/libautoschedule_anderson2021.so",
)
_TARGET = "host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts"
_AUTOSCHEDULER = "Anderson2021"
_PARALLELISM = 82

if __name__ == "__main__":
    # Run through Halide's generator command-line driver. hl.main() reads the
    # arguments assigned to sys.argv below, and its outputs go into a
    # throwaway temporary directory.
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            "repro.py",
            "-g",
            "kernel",
            "-o",
            out,
            "-f",
            "halide_kernel",
            "-e",
            "static_library,h,schedule",
            "-p",
            _PLUGIN_PATH,
            f"target={_TARGET}",
            f"autoscheduler={_AUTOSCHEDULER}",
            f"autoscheduler.parallelism={_PARALLELISM}",
        ]
        hl.main()

else:
    # When imported as a module: build the pipeline in-process, run the
    # autoscheduler explicitly, and compile it to a callable bound to `kernel`.
    # NOTE: relies on the underscore-prefixed Generator helpers
    # _build_pipeline, _get_arginfos and _get_input_parameter.
    hl.load_plugin(_PLUGIN_PATH)
    target = hl.Target(_TARGET)
    autoscheduler = hl.AutoschedulerParams(
        _AUTOSCHEDULER, {"parallelism": _PARALLELISM}
    )
    with hl.GeneratorContext(target, autoscheduler):
        gen = Kernel()
        pipeline = gen._build_pipeline()
        # gen.compile_to_callable() does not run the autoscheduler
        pipeline.apply_autoscheduler(target, autoscheduler)
        kernel = pipeline.compile_to_callable(
            [
                gen._get_input_parameter(a.name)._to_argument()
                for a in gen._get_arginfos()
                if a.dir == hl.ArgInfoDirection.Input
            ],
            target,
        )