halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org
Other
5.91k stars 1.07k forks source link

32 minute compile time for max_pool2d_with_indices #8429

Open jansel opened 1 month ago

jansel commented 1 month ago

This example takes 32 minutes to compile, while typical kernels take seconds (not minutes). I suspect it is hitting some sort of pathological case in Halide.

repro.py

import halide as hl
from torch._inductor.runtime import halide_helpers
from math import inf, nan

@hl.generator(name="kernel")
class Kernel:
    """Halide generator for the max_pool2d_with_indices repro (halide/Halide#8429).

    For every output point (h0, h1, h2, h3) it reads nine candidate elements of
    ``in_ptr0``. It writes their NaN-propagating maximum to ``out_ptr0``. It
    writes an int64 index, derived from which of the nine candidates won, to
    ``out_ptr2``.

    Every intermediate value is defined as its own ``hl.Func``. In the issue
    discussion, the very long build was attributed to the Anderson2021
    autoscheduler enumerating tiling options for this many Funcs, not to code
    generation itself.
    """

    # 5-D float32 input; the nine candidates per output point are read from it.
    in_ptr0 = hl.InputBuffer(hl.Float(32), 5)
    # 4-D float32 output: running maximum after all nine candidates (tmp16).
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 4)
    # 4-D int64 output: index computed from the winning candidate (tmp52).
    out_ptr2 = hl.OutputBuffer(hl.Int(64), 4)

    def generate(g):
        """Define the pipeline, then pin the input layout and autoscheduler estimates.

        ``g`` is the generator instance (used here in place of ``self``).
        """
        in_ptr0 = g.in_ptr0
        out_ptr0 = g.out_ptr0
        out_ptr2 = g.out_ptr2
        h0 = hl.Var("h0")
        h1 = hl.Var("h1")
        h2 = hl.Var("h2")
        h3 = hl.Var("h3")
        # --- Value chain: running maximum over nine candidates. ---
        # Candidates 0 and 1: in_ptr0[0, h0, h1, ...] and in_ptr0[1, h0, h1, ...].
        tmp0 = hl.Func("tmp0")
        tmp0[h0, h1, h2, h3] = in_ptr0[0, h0, h1, h2, h3]
        tmp1 = hl.Func("tmp1")
        tmp1[h0, h1, h2, h3] = in_ptr0[1, h0, h1, h2, h3]
        # Running max over candidates 0-1. The new candidate is taken if it is
        # larger than the running value OR is NaN, so NaNs propagate.
        # The `... if tmpN.type().is_float() else hl.max(...)` is a Python
        # conditional evaluated while building the pipeline. It picks select()
        # for float types and hl.max() otherwise. Every candidate here comes
        # from the Float(32) in_ptr0, so the select() form is used throughout.
        tmp2 = hl.Func("tmp2")
        tmp2[h0, h1, h2, h3] = (
            hl.select(
                (tmp1[h0, h1, h2, h3] > hl.cast(tmp1.type(), tmp0[h0, h1, h2, h3]))
                | hl.is_nan(tmp1[h0, h1, h2, h3]),
                tmp1[h0, h1, h2, h3],
                hl.cast(tmp1.type(), tmp0[h0, h1, h2, h3]),
            )
            if tmp1.type().is_float()
            else hl.max(
                tmp1[h0, h1, h2, h3], hl.cast(tmp1.type(), tmp0[h0, h1, h2, h3])
            )
        )
        # Candidate 2: in_ptr0[0, 1 + h0, h1, ...].
        tmp3 = hl.Func("tmp3")
        tmp3[h0, h1, h2, h3] = in_ptr0[0, 1 + h0, h1, h2, h3]
        # Running max over candidates 0-2.
        tmp4 = hl.Func("tmp4")
        tmp4[h0, h1, h2, h3] = (
            hl.select(
                (tmp3[h0, h1, h2, h3] > hl.cast(tmp3.type(), tmp2[h0, h1, h2, h3]))
                | hl.is_nan(tmp3[h0, h1, h2, h3]),
                tmp3[h0, h1, h2, h3],
                hl.cast(tmp3.type(), tmp2[h0, h1, h2, h3]),
            )
            if tmp3.type().is_float()
            else hl.max(
                tmp3[h0, h1, h2, h3], hl.cast(tmp3.type(), tmp2[h0, h1, h2, h3])
            )
        )
        # Candidate 3: in_ptr0[1, 13 + h0, h1, ...].
        tmp5 = hl.Func("tmp5")
        tmp5[h0, h1, h2, h3] = in_ptr0[1, 13 + h0, h1, h2, h3]
        # Running max over candidates 0-3.
        tmp6 = hl.Func("tmp6")
        tmp6[h0, h1, h2, h3] = (
            hl.select(
                (tmp5[h0, h1, h2, h3] > hl.cast(tmp5.type(), tmp4[h0, h1, h2, h3]))
                | hl.is_nan(tmp5[h0, h1, h2, h3]),
                tmp5[h0, h1, h2, h3],
                hl.cast(tmp5.type(), tmp4[h0, h1, h2, h3]),
            )
            if tmp5.type().is_float()
            else hl.max(
                tmp5[h0, h1, h2, h3], hl.cast(tmp5.type(), tmp4[h0, h1, h2, h3])
            )
        )
        # Candidate 4: in_ptr0[0, 14 + h0, h1, ...].
        tmp7 = hl.Func("tmp7")
        tmp7[h0, h1, h2, h3] = in_ptr0[0, 14 + h0, h1, h2, h3]
        # Running max over candidates 0-4.
        tmp8 = hl.Func("tmp8")
        tmp8[h0, h1, h2, h3] = (
            hl.select(
                (tmp7[h0, h1, h2, h3] > hl.cast(tmp7.type(), tmp6[h0, h1, h2, h3]))
                | hl.is_nan(tmp7[h0, h1, h2, h3]),
                tmp7[h0, h1, h2, h3],
                hl.cast(tmp7.type(), tmp6[h0, h1, h2, h3]),
            )
            if tmp7.type().is_float()
            else hl.max(
                tmp7[h0, h1, h2, h3], hl.cast(tmp7.type(), tmp6[h0, h1, h2, h3])
            )
        )
        # Candidate 5: in_ptr0[1, 14 + h0, h1, ...].
        tmp9 = hl.Func("tmp9")
        tmp9[h0, h1, h2, h3] = in_ptr0[1, 14 + h0, h1, h2, h3]
        # Running max over candidates 0-5.
        tmp10 = hl.Func("tmp10")
        tmp10[h0, h1, h2, h3] = (
            hl.select(
                (tmp9[h0, h1, h2, h3] > hl.cast(tmp9.type(), tmp8[h0, h1, h2, h3]))
                | hl.is_nan(tmp9[h0, h1, h2, h3]),
                tmp9[h0, h1, h2, h3],
                hl.cast(tmp9.type(), tmp8[h0, h1, h2, h3]),
            )
            if tmp9.type().is_float()
            else hl.max(
                tmp9[h0, h1, h2, h3], hl.cast(tmp9.type(), tmp8[h0, h1, h2, h3])
            )
        )
        # Candidate 6: in_ptr0[0, h0, 1 + h1, ...].
        tmp11 = hl.Func("tmp11")
        tmp11[h0, h1, h2, h3] = in_ptr0[0, h0, 1 + h1, h2, h3]
        # Running max over candidates 0-6.
        tmp12 = hl.Func("tmp12")
        tmp12[h0, h1, h2, h3] = (
            hl.select(
                (tmp11[h0, h1, h2, h3] > hl.cast(tmp11.type(), tmp10[h0, h1, h2, h3]))
                | hl.is_nan(tmp11[h0, h1, h2, h3]),
                tmp11[h0, h1, h2, h3],
                hl.cast(tmp11.type(), tmp10[h0, h1, h2, h3]),
            )
            if tmp11.type().is_float()
            else hl.max(
                tmp11[h0, h1, h2, h3], hl.cast(tmp11.type(), tmp10[h0, h1, h2, h3])
            )
        )
        # Candidate 7: in_ptr0[1, h0, 1 + h1, ...].
        tmp13 = hl.Func("tmp13")
        tmp13[h0, h1, h2, h3] = in_ptr0[1, h0, 1 + h1, h2, h3]
        # Running max over candidates 0-7.
        tmp14 = hl.Func("tmp14")
        tmp14[h0, h1, h2, h3] = (
            hl.select(
                (tmp13[h0, h1, h2, h3] > hl.cast(tmp13.type(), tmp12[h0, h1, h2, h3]))
                | hl.is_nan(tmp13[h0, h1, h2, h3]),
                tmp13[h0, h1, h2, h3],
                hl.cast(tmp13.type(), tmp12[h0, h1, h2, h3]),
            )
            if tmp13.type().is_float()
            else hl.max(
                tmp13[h0, h1, h2, h3], hl.cast(tmp13.type(), tmp12[h0, h1, h2, h3])
            )
        )
        # Candidate 8: in_ptr0[0, 1 + h0, 1 + h1, ...].
        tmp15 = hl.Func("tmp15")
        tmp15[h0, h1, h2, h3] = in_ptr0[0, 1 + h0, 1 + h1, h2, h3]
        # Running max over all nine candidates.
        tmp16 = hl.Func("tmp16")
        tmp16[h0, h1, h2, h3] = (
            hl.select(
                (tmp15[h0, h1, h2, h3] > hl.cast(tmp15.type(), tmp14[h0, h1, h2, h3]))
                | hl.is_nan(tmp15[h0, h1, h2, h3]),
                tmp15[h0, h1, h2, h3],
                hl.cast(tmp15.type(), tmp14[h0, h1, h2, h3]),
            )
            if tmp15.type().is_float()
            else hl.max(
                tmp15[h0, h1, h2, h3], hl.cast(tmp15.type(), tmp14[h0, h1, h2, h3])
            )
        )
        # First output: the maximum value.
        out_ptr0[h0, h1, h2, h3] = hl.cast(hl.Float(32), tmp16[h0, h1, h2, h3])
        # --- Index chain: which candidate id (0..8) produced the running max. ---
        # Each step compares candidate k against the running max *before* it
        # (the same pairs as the value chain). If the candidate wins, the step
        # records k (an int8 constant held in a zero-dimensional Func).
        # Otherwise it keeps the previous id.
        # NOTE(review): these comparisons use a plain `>` without the is_nan()
        # term the value chain uses. When a NaN candidate "wins" the value,
        # the recorded id therefore does not move to it. Confirm this is the
        # intended semantics.
        # Candidate 1 vs candidate 0: id 1 if it wins, else 0.
        tmp17 = hl.Func("tmp17")
        tmp17[h0, h1, h2, h3] = tmp1[h0, h1, h2, h3] > tmp0[h0, h1, h2, h3]
        tmp18 = hl.Func("tmp18")
        tmp18[()] = hl.cast(hl.Int(8), 1)
        tmp19 = hl.Func("tmp19")
        tmp19[()] = hl.cast(hl.Int(8), 0)
        tmp20 = hl.Func("tmp20")
        tmp20[h0, h1, h2, h3] = hl.select(
            tmp17[h0, h1, h2, h3], tmp18[()], hl.cast(tmp18.type(), tmp19[()])
        )
        # Candidate 2 -> id 2.
        tmp21 = hl.Func("tmp21")
        tmp21[h0, h1, h2, h3] = tmp3[h0, h1, h2, h3] > tmp2[h0, h1, h2, h3]
        tmp22 = hl.Func("tmp22")
        tmp22[()] = hl.cast(hl.Int(8), 2)
        tmp23 = hl.Func("tmp23")
        tmp23[h0, h1, h2, h3] = hl.select(
            tmp21[h0, h1, h2, h3],
            tmp22[()],
            hl.cast(tmp22.type(), tmp20[h0, h1, h2, h3]),
        )
        # Candidate 3 -> id 3.
        tmp24 = hl.Func("tmp24")
        tmp24[h0, h1, h2, h3] = tmp5[h0, h1, h2, h3] > tmp4[h0, h1, h2, h3]
        tmp25 = hl.Func("tmp25")
        tmp25[()] = hl.cast(hl.Int(8), 3)
        tmp26 = hl.Func("tmp26")
        tmp26[h0, h1, h2, h3] = hl.select(
            tmp24[h0, h1, h2, h3],
            tmp25[()],
            hl.cast(tmp25.type(), tmp23[h0, h1, h2, h3]),
        )
        # Candidate 4 -> id 4.
        tmp27 = hl.Func("tmp27")
        tmp27[h0, h1, h2, h3] = tmp7[h0, h1, h2, h3] > tmp6[h0, h1, h2, h3]
        tmp28 = hl.Func("tmp28")
        tmp28[()] = hl.cast(hl.Int(8), 4)
        tmp29 = hl.Func("tmp29")
        tmp29[h0, h1, h2, h3] = hl.select(
            tmp27[h0, h1, h2, h3],
            tmp28[()],
            hl.cast(tmp28.type(), tmp26[h0, h1, h2, h3]),
        )
        # Candidate 5 -> id 5.
        tmp30 = hl.Func("tmp30")
        tmp30[h0, h1, h2, h3] = tmp9[h0, h1, h2, h3] > tmp8[h0, h1, h2, h3]
        tmp31 = hl.Func("tmp31")
        tmp31[()] = hl.cast(hl.Int(8), 5)
        tmp32 = hl.Func("tmp32")
        tmp32[h0, h1, h2, h3] = hl.select(
            tmp30[h0, h1, h2, h3],
            tmp31[()],
            hl.cast(tmp31.type(), tmp29[h0, h1, h2, h3]),
        )
        # Candidate 6 -> id 6.
        tmp33 = hl.Func("tmp33")
        tmp33[h0, h1, h2, h3] = tmp11[h0, h1, h2, h3] > tmp10[h0, h1, h2, h3]
        tmp34 = hl.Func("tmp34")
        tmp34[()] = hl.cast(hl.Int(8), 6)
        tmp35 = hl.Func("tmp35")
        tmp35[h0, h1, h2, h3] = hl.select(
            tmp33[h0, h1, h2, h3],
            tmp34[()],
            hl.cast(tmp34.type(), tmp32[h0, h1, h2, h3]),
        )
        # Candidate 7 -> id 7.
        tmp36 = hl.Func("tmp36")
        tmp36[h0, h1, h2, h3] = tmp13[h0, h1, h2, h3] > tmp12[h0, h1, h2, h3]
        tmp37 = hl.Func("tmp37")
        tmp37[()] = hl.cast(hl.Int(8), 7)
        tmp38 = hl.Func("tmp38")
        tmp38[h0, h1, h2, h3] = hl.select(
            tmp36[h0, h1, h2, h3],
            tmp37[()],
            hl.cast(tmp37.type(), tmp35[h0, h1, h2, h3]),
        )
        # Candidate 8 -> id 8. tmp41 is the final winning id.
        tmp39 = hl.Func("tmp39")
        tmp39[h0, h1, h2, h3] = tmp15[h0, h1, h2, h3] > tmp14[h0, h1, h2, h3]
        tmp40 = hl.Func("tmp40")
        tmp40[()] = hl.cast(hl.Int(8), 8)
        tmp41 = hl.Func("tmp41")
        tmp41[h0, h1, h2, h3] = hl.select(
            tmp39[h0, h1, h2, h3],
            tmp40[()],
            hl.cast(tmp40.type(), tmp38[h0, h1, h2, h3]),
        )
        # --- Decode the id into a flat index. ---
        # tmp43 = floor(id / 3) and tmp45 = id - 3 * floor(id / 3), i.e. the
        # quotient and remainder of the id by 3.
        # The division is done after casting to a float of at least 32 bits,
        # and hl.floor keeps a floating-point result. The following arithmetic
        # presumably stays floating point until the final Int(64) cast; the
        # values involved are small integers.
        tmp42 = hl.Func("tmp42")
        tmp42[()] = hl.cast(hl.Int(32), 3)
        tmp43 = hl.Func("tmp43")
        tmp43[h0, h1, h2, h3] = hl.floor(
            hl.cast(hl.Float(max(32, tmp41.type().bits())), tmp41[h0, h1, h2, h3])
            / tmp42[()]
        )
        tmp44 = hl.Func("tmp44")
        tmp44[h0, h1, h2, h3] = tmp43[h0, h1, h2, h3] * tmp42[()]
        tmp45 = hl.Func("tmp45")
        tmp45[h0, h1, h2, h3] = tmp41[h0, h1, h2, h3] - tmp44[h0, h1, h2, h3]
        # tmp47 = 2 * h1 + floor(id / 3).
        tmp46 = hl.Func("tmp46")
        tmp46[h1] = 2 * h1
        tmp47 = hl.Func("tmp47")
        tmp47[h0, h1, h2, h3] = tmp46[h1] + tmp43[h0, h1, h2, h3]
        # tmp49 = 2 * h0 + (id mod 3).
        tmp48 = hl.Func("tmp48")
        tmp48[h0] = 2 * h0
        tmp49 = hl.Func("tmp49")
        tmp49[h0, h1, h2, h3] = tmp48[h0] + tmp45[h0, h1, h2, h3]
        # tmp52 = tmp47 * 27 + tmp49. This looks like a row-major flat index
        # into a 27-wide plane of the original input (TODO: confirm against
        # the inductor lowering).
        tmp50 = hl.Func("tmp50")
        tmp50[()] = hl.cast(hl.Int(64), 27)
        tmp51 = hl.Func("tmp51")
        tmp51[h0, h1, h2, h3] = tmp47[h0, h1, h2, h3] * tmp50[()]
        tmp52 = hl.Func("tmp52")
        tmp52[h0, h1, h2, h3] = tmp51[h0, h1, h2, h3] + tmp49[h0, h1, h2, h3]
        # Second output: the index, as int64.
        out_ptr2[h0, h1, h2, h3] = hl.cast(hl.Int(64), tmp52[h0, h1, h2, h3])

        # No manual schedule is provided; this generator must be scheduled by
        # an autoscheduler.
        assert g.using_autoscheduler()
        # Pin in_ptr0's layout at compile time. All mins are 0. dim 0 has
        # stride 1 and extent 2. Dims 1-4 have strides 2, 54, 729 and 139968
        # with extents 13, 13, 192 and 128.
        # NOTE(review): the loads above use dim-1 indices up to 14 + h0 and
        # dim-2 indices up to 1 + h1. For the estimated h0, h1 in [0, 13)
        # these exceed the declared extent of 13. With these strides, the flat
        # offset within one dim-3 step stays at most 728 (< 729). This
        # presumably relies on the real allocation being a larger plane viewed
        # through this shape; confirm.
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(2)
        in_ptr0.dim(1).set_min(0)
        in_ptr0.dim(1).set_stride(2)
        in_ptr0.dim(1).set_extent(13)
        in_ptr0.dim(2).set_min(0)
        in_ptr0.dim(2).set_stride(54)
        in_ptr0.dim(2).set_extent(13)
        in_ptr0.dim(3).set_min(0)
        in_ptr0.dim(3).set_stride(729)
        in_ptr0.dim(3).set_extent(192)
        in_ptr0.dim(4).set_min(0)
        in_ptr0.dim(4).set_stride(139968)
        in_ptr0.dim(4).set_extent(128)
        # Size estimates consumed by the autoscheduler. The input estimates
        # match the extents declared above; both outputs are estimated at
        # 13 x 13 x 192 x 128.
        in_ptr0.set_estimates(
            [
                hl.Range(0, 2),
                hl.Range(0, 13),
                hl.Range(0, 13),
                hl.Range(0, 192),
                hl.Range(0, 128),
            ]
        )
        out_ptr0.set_estimates(
            [hl.Range(0, 13), hl.Range(0, 13), hl.Range(0, 192), hl.Range(0, 128)]
        )
        out_ptr2.set_estimates(
            [hl.Range(0, 13), hl.Range(0, 13), hl.Range(0, 192), hl.Range(0, 128)]
        )

if __name__ == "__main__":
    import os
    import sys
    import tempfile

    # Locate the Anderson2021 autoscheduler plugin next to the installed
    # `halide` Python package. The original repro hard-coded a path inside
    # one developer's conda env
    # (.../site-packages/halide/lib64/libautoschedule_anderson2021.so), so it
    # could not run anywhere else. This derives the same relative location
    # from wherever `halide` is actually installed.
    # NOTE(review): assumes the package keeps its plugins in `halide/lib64`,
    # as in the original environment; other platforms/builds may differ.
    halide_dir = os.path.dirname(os.path.abspath(hl.__file__))
    plugin = os.path.join(halide_dir, "lib64", "libautoschedule_anderson2021.so")
    if not os.path.isfile(plugin):
        sys.exit(f"Anderson2021 autoscheduler plugin not found: {plugin}")

    # Generated artifacts go to a throwaway directory. The repro only cares
    # about how long it takes to produce them.
    with tempfile.TemporaryDirectory() as out:
        # hl.main() reads its options from sys.argv, so build the generator
        # command line here.
        sys.argv = [
            "repro.py",
            "-g",
            "kernel",  # generator registered via @hl.generator(name="kernel")
            "-o",
            out,
            "-f",
            "halide_kernel",
            "-e",
            "static_library,h,schedule",
            "-p",
            plugin,  # autoscheduler plugin library to load
            "target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts",
            "autoscheduler=Anderson2021",
            "autoscheduler.parallelism=82",
        ]
        hl.main()

cc @alexreinking: this example comes from:

python benchmarks/dynamo/microbenchmarks/operatorbench.py --inductor-config autotune --inductor-config halide --op aten.max_pool2d_with_indices.default --max-samples 1 --start-idx 4

on https://github.com/pytorch/pytorch/pull/136809

abadams commented 1 month ago

So it takes 32 minutes... but still successfully compiles? Interesting. Maybe there's a lurking pass with exponential complexity for this example.

jansel commented 1 month ago

Yeah, it finishes and runs correctly.

abadams commented 1 month ago

Looks like it's not compilation proper, but rather the Anderson2021 autoscheduler getting stuck enumerating a combinatorial number of tiling options, which is a bit absurd given that this entire pipeline seems to be elementwise other than accesses to the input buffer.

A workaround would be to ask the autoscheduler to do a lot less by generating an Expr instead of a Func for anything that has no update definition and is either consumed elementwise or is an op that is cheaper than a load (e.g. tmp48).