halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Signed integer overflow occurred during constant-folding caused by autoschedulers #8277

Open jansel opened 3 weeks ago

jansel commented 3 weeks ago

This issue is the same error as #8227, but it is not fixed by #8234.

I am running https://github.com/halide/Halide/commit/340136fec6d3ebc73e7a19eba1663e9b0ba8ab2d, which includes #8234. Unlike #8227, this one only triggers when an autoscheduler is used.

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
    out_ptr1 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        out_ptr1 = g.out_ptr1
        xindex = hl.Var("xindex")
        x1 = (xindex // 478) % 320
        x0 = xindex % 478
        x2 = xindex // 152960
        x4 = xindex
        tmp0 = hl.Func("tmp0")
        tmp0[xindex] = x1
        tmp1 = hl.Func("tmp1")
        tmp1[xindex] = hl.cast(hl.Float(32), tmp0[xindex])
        tmp2 = hl.cast(hl.Float(32), hl.f64(0.49843260188087773))
        tmp3 = hl.Func("tmp3")
        tmp3[xindex] = tmp1[xindex] * tmp2
        tmp4 = hl.cast(hl.Float(32), hl.f64(0.0))
        tmp5 = hl.Func("tmp5")
        tmp5[xindex] = hl.max(tmp3[xindex], tmp4)
        tmp6 = hl.Func("tmp6")
        tmp6[xindex] = hl.cast(hl.Int(32), tmp5[xindex])
        tmp7 = hl.cast(hl.Int(64), 1)
        tmp8 = hl.Func("tmp8")
        tmp8[xindex] = tmp6[xindex] + tmp7
        tmp9 = hl.cast(hl.Int(64), 159)
        tmp10 = hl.Func("tmp10")
        tmp10[xindex] = hl.min(tmp8[xindex], tmp9)
        tmp11 = hl.Func("tmp11")
        tmp11[xindex] = hl.cast(hl.Int(32), tmp10[xindex])
        tmp12 = hl.Func("tmp12")
        tmp12[xindex] = x0
        tmp13 = hl.Func("tmp13")
        tmp13[xindex] = hl.cast(hl.Float(32), tmp12[xindex])
        tmp14 = hl.cast(hl.Float(32), hl.f64(0.4989517819706499))
        tmp15 = hl.Func("tmp15")
        tmp15[xindex] = tmp13[xindex] * tmp14
        tmp16 = hl.Func("tmp16")
        tmp16[xindex] = hl.max(tmp15[xindex], tmp4)
        tmp17 = hl.Func("tmp17")
        tmp17[xindex] = hl.cast(hl.Int(32), tmp16[xindex])
        tmp18 = hl.Func("tmp18")
        tmp18[xindex] = hl.cast(hl.Int(32), tmp17[xindex])
        tmp19 = hl.Func("tmp19")
        tmp19[xindex] = in_ptr0[
            hl.clamp(
                x2 + (128 * (tmp18[xindex])) + (30592 * (tmp11[xindex])), 0, 4894719
            )
        ]
        tmp20 = hl.Func("tmp20")
        tmp20[xindex] = tmp17[xindex] + tmp7
        tmp21 = hl.cast(hl.Int(64), 238)
        tmp22 = hl.Func("tmp22")
        tmp22[xindex] = hl.min(tmp20[xindex], tmp21)
        tmp23 = hl.Func("tmp23")
        tmp23[xindex] = hl.cast(hl.Int(32), tmp22[xindex])
        tmp24 = hl.Func("tmp24")
        tmp24[xindex] = in_ptr0[
            hl.clamp(
                x2 + (128 * (tmp23[xindex])) + (30592 * (tmp11[xindex])), 0, 4894719
            )
        ]
        tmp25 = hl.Func("tmp25")
        tmp25[xindex] = tmp24[xindex] - tmp19[xindex]
        tmp26 = hl.Func("tmp26")
        tmp26[xindex] = hl.cast(hl.Float(32), tmp17[xindex])
        tmp27 = hl.Func("tmp27")
        tmp27[xindex] = tmp16[xindex] - tmp26[xindex]
        tmp28 = hl.Func("tmp28")
        tmp28[xindex] = hl.max(tmp27[xindex], tmp4)
        tmp29 = hl.cast(hl.Float(32), hl.f64(1.0))
        tmp30 = hl.Func("tmp30")
        tmp30[xindex] = hl.min(tmp28[xindex], tmp29)
        tmp31 = hl.Func("tmp31")
        tmp31[xindex] = tmp25[xindex] * tmp30[xindex]
        tmp32 = hl.Func("tmp32")
        tmp32[xindex] = tmp19[xindex] + tmp31[xindex]
        tmp33 = hl.Func("tmp33")
        tmp33[xindex] = hl.cast(hl.Int(32), tmp6[xindex])
        tmp34 = hl.Func("tmp34")
        tmp34[xindex] = in_ptr0[
            hl.clamp(
                x2 + (128 * (tmp18[xindex])) + (30592 * (tmp33[xindex])), 0, 4894719
            )
        ]
        tmp35 = hl.Func("tmp35")
        tmp35[xindex] = in_ptr0[
            hl.clamp(
                x2 + (128 * (tmp23[xindex])) + (30592 * (tmp33[xindex])), 0, 4894719
            )
        ]
        tmp36 = hl.Func("tmp36")
        tmp36[xindex] = tmp35[xindex] - tmp34[xindex]
        tmp37 = hl.Func("tmp37")
        tmp37[xindex] = tmp36[xindex] * tmp30[xindex]
        tmp38 = hl.Func("tmp38")
        tmp38[xindex] = tmp34[xindex] + tmp37[xindex]
        tmp39 = hl.Func("tmp39")
        tmp39[xindex] = tmp32[xindex] - tmp38[xindex]
        tmp40 = hl.Func("tmp40")
        tmp40[xindex] = hl.cast(hl.Float(32), tmp6[xindex])
        tmp41 = hl.Func("tmp41")
        tmp41[xindex] = tmp5[xindex] - tmp40[xindex]
        tmp42 = hl.Func("tmp42")
        tmp42[xindex] = hl.max(tmp41[xindex], tmp4)
        tmp43 = hl.Func("tmp43")
        tmp43[xindex] = hl.min(tmp42[xindex], tmp29)
        tmp44 = hl.Func("tmp44")
        tmp44[xindex] = tmp39[xindex] * tmp43[xindex]
        tmp45 = hl.Func("tmp45")
        tmp45[xindex] = tmp38[xindex] + tmp44[xindex]
        out_ptr1[x4] = hl.cast(out_ptr1.type(), tmp45[xindex])

        in_ptr0.set_estimates([hl.Range(0, 4894720)])
        out_ptr1.set_estimates([hl.Range(0, 19578880)])

hl.load_plugin("/home/jansel/conda/envs/pytorch/lib/libautoschedule_adams2019.so")
hl.load_plugin("/home/jansel/conda/envs/pytorch/lib/libautoschedule_li2018.so")
hl.load_plugin("/home/jansel/conda/envs/pytorch/lib/libautoschedule_mullapudi2016.so")
target = hl.Target("host-strict_float-no_runtime-no_asserts")
autoscheduler0 = hl.AutoschedulerParams("", {})
autoscheduler1 = hl.AutoschedulerParams("Mullapudi2016", {"parallelism": 8})
autoscheduler2 = hl.AutoschedulerParams("Li2018", {"parallelism": 8})
autoscheduler3 = hl.AutoschedulerParams("Adams2019", {"parallelism": 8})

for autoscheduler in (autoscheduler0, autoscheduler1, autoscheduler2, autoscheduler3):
    print("Running autoscheduler", autoscheduler.name)
    with hl.GeneratorContext(target, autoscheduler):
        gen = Kernel()
        # gen.compile_to_callable() does not run the autoscheduler
        pipeline = gen._build_pipeline()
        if autoscheduler.name:
            pipeline.apply_autoscheduler(target, autoscheduler)
        try:
            kernel = pipeline.compile_to_callable(
                [
                    gen._get_input_parameter(a.name)._to_argument()
                    for a in gen._get_arginfos()
                    if a.dir == hl.ArgInfoDirection.Input
                ],
                target,
            )
            print("OK\n")
        except Exception as e:
            print("Failed with", e)

Output:

$ python repro.py 
Running autoscheduler 
OK

Running autoscheduler Mullapudi2016
Failed with Error: Signed integer overflow occurred during constant-folding. Signed integer overflow for int32 and int64 is undefined behavior in Halide.

Running autoscheduler Li2018
Failed with Error: Signed integer overflow occurred during constant-folding. Signed integer overflow for int32 and int64 is undefined behavior in Halide.

Running autoscheduler Adams2019
Failed with Error: Signed integer overflow occurred during constant-folding. Signed integer overflow for int32 and int64 is undefined behavior in Halide.

abadams commented 3 weeks ago

I've investigated, and I believe this can't really be fixed without making existing code worse unless we fix #3245 entirely. The specific issue is:

        x2 = xindex // 152960

Given #3245, this sort of code in indexing expressions is going to cause problems. Halide tries to solve equations involving these constants to do things like loop partitioning, and in this case 152960 gets multiplied by 30592 at some point; the product (roughly 4.7 billion) no longer fits in a signed 32-bit integer, so the compiler essentially panics. A workaround is to use Params instead of int literals for your tensor shapes to avoid all these large constants, but it's an ugly workaround that doesn't fix #3245 and doesn't solve the other big problem here:
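
A minimal sketch of that workaround, assuming the flattened shape products can be passed in as scalar inputs (the names hw and chw are illustrative, not from the repro). Because the divisors and strides become symbolic, the simplifier never sees a 152960 * 30592 constant product to fold:

import halide as hl

@hl.generator(name="kernel_param_shapes")
class KernelParamShapes:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
    hw = hl.InputScalar(hl.Int(32))     # e.g. 478, supplied at call time
    chw = hl.InputScalar(hl.Int(32))    # e.g. 152960, supplied at call time
    out_ptr1 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        xindex = hl.Var("xindex")
        hw = hl.cast(hl.Int(32), g.hw)
        chw = hl.cast(hl.Int(32), g.chw)
        x1 = (xindex // hw) % 320       # strides are now Params, not literal constants
        x0 = xindex % hw
        x2 = xindex // chw
        # ... the rest of the indexing arithmetic as in the repro ...
        out_ptr1 = g.out_ptr1
        # Placeholder body, abbreviated; the real kernel computes the interpolation.
        out_ptr1[xindex] = g.in_ptr0[hl.clamp(x2 + x0 + x1, 0, 4894719)]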

This 1D indexing is terrible for performance. It adds expensive divs and mods, can break vectorization (because dense loads become vector gathers), makes the autoscheduler useless (because there's a single specific scheduling idiom you need to apply to remove the div and mod), makes large_buffers useless (because large_buffers means each dim is still 31-bit, but the product of the dims may be 63-bit), and generally makes a mess in the IR. I strongly recommend not using 1D inputs and outputs.
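
For contrast, a rough sketch of the multi-dimensional form being recommended here. The shapes are inferred from the repro's constants (a 478x320x128 output gathered from a 239x160x128 input, which looks like a 2x bilinear upsample), so treat them as an assumption, and the interpolation body is abbreviated:

import halide as hl

@hl.generator(name="kernel_3d")
class Kernel3D:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 3)
    out_ptr1 = hl.OutputBuffer(hl.Float(32), 3)

    def generate(g):
        x, y, c = hl.Var("x"), hl.Var("y"), hl.Var("c")
        out_ptr1 = g.out_ptr1
        # x, y, c play the roles of x0, x1, x2 from the repro. Each index gets
        # its own dimension, so the div/mod by 478 and 152960 disappear and the
        # per-dimension clamps replace the single flattened clamp; no
        # 152960 * 30592 style product ever appears. A real kernel would compute
        # the bilinear weights here.
        out_ptr1[x, y, c] = g.in_ptr0[c, hl.clamp(x // 2, 0, 238), hl.clamp(y // 2, 0, 159)]

        g.in_ptr0.set_estimates([hl.Range(0, 128), hl.Range(0, 239), hl.Range(0, 160)])
        out_ptr1.set_estimates([hl.Range(0, 478), hl.Range(0, 320), hl.Range(0, 128)])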

You said earlier that you're doing it because there are views that pytorch supports that can't be expressed with affine indexing. Can you elaborate?

jansel commented 3 weeks ago

I'll write a pass to convert to dimensions where possible. It is often possible, but not always.
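
As a rough sketch of one simple sufficient condition such a pass could check (illustrative logic only, not the actual pass): a flattened view can be turned back into plain dimensions when its strides, sorted, form a dense non-overlapping layout; stride-0 broadcasts, overlaps, and gaps are among the cases that cannot.

def contiguous_dim_order(sizes, strides):
    """Return a dim order that makes (sizes, strides) a dense, non-overlapping
    layout, or None if the view needs as_strided-style flattened indexing."""
    order = sorted(range(len(sizes)), key=lambda d: strides[d])
    expected = 1
    for d in order:
        if strides[d] != expected:
            return None           # stride-0 broadcast, overlap, or gap
        expected *= sizes[d]
    return order

# contiguous_dim_order([2, 3], [1, 2])  ->  [0, 1]  (expressible as dimensions)
# contiguous_dim_order([2, 3], [1, 1])  ->  None    (overlapping view, keep 1D indexing)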

A good example to illustrate the issue is torch.as_strided(), which takes the memory of an existing tensor and reinterprets it as a new tensor with a caller-provided list of sizes/strides, possibly permuted arbitrarily with respect to memory order and possibly with a different number of dimensions.
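
For concreteness, a tiny self-contained example of the semantics (standard torch API; the values are chosen only for illustration):

import torch

storage = torch.arange(8.0)    # flat memory: [0, 1, 2, 3, 4, 5, 6, 7]
v = torch.as_strided(storage, size=(2, 3), stride=(1, 2))
# v[i, j] reads storage[i*1 + j*2], so v == [[0, 2, 4], [1, 3, 5]].
# The strides need not follow any dim order or even be non-overlapping,
# so the flattened index is affine in (i, j) but otherwise unconstrained.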

There are lots of other view ops (view/reshape/permute/transpose/slice/index/etc), but all of them can be mapped to as_strided.

When using dimensions, it was hard to map as_strided to Halide; maybe there is some op I was missing. A single sizes/strides combination per input can be represented, but not multiple combinations, nor ones that change mid-program.
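
To make the representable case concrete, a hedged sketch of lowering one fixed as_strided view to affine indexing over a 1D Halide buffer; strided_load is a helper made up for this illustration, not a Halide API:

import halide as hl

def strided_load(storage, strides, offset, coords):
    # Element of the fixed strided view at integer coordinates `coords`:
    # storage[offset + sum(coords[k] * strides[k])].
    idx = hl.cast(hl.Int(32), offset)
    for c, s in zip(coords, strides):
        idx = idx + c * s
    return storage[idx]

# A 3x2 view with strides (2, 3) and offset 1 over a flat 16-element buffer.
storage = hl.Buffer(hl.Float(32), [16])
i, j = hl.Var("i"), hl.Var("j")
view = hl.Func("view")
view[i, j] = strided_load(storage, strides=(2, 3), offset=1, coords=(i, j))
# This works because (sizes, strides, offset) are fixed for the whole pipeline;
# a program that re-strides the same storage differently partway through has no
# single per-input layout to declare, which is the hard case.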