Open jansel opened 4 months ago
Here is another (larger) example of this same error:
import halide as hl
from math import inf, nan
# Halide generator registered under the name "kernel".  It declares no manual
# schedule: generate() asserts that an autoscheduler is in use.
@hl.generator(name="kernel")
class Kernel:
    # Inputs: a 2-D float32 buffer and a 1-D int64 buffer.
    in_ptr0 = hl.InputBuffer(hl.Float(32), 2)
    in_ptr1 = hl.InputBuffer(hl.Int(64), 1)
    # Outputs: four 1-D float32 buffers.
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)
    out_ptr1 = hl.OutputBuffer(hl.Float(32), 1)
    out_ptr4 = hl.OutputBuffer(hl.Float(32), 1)
    out_ptr5 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        """Define the pipeline, then pin buffer layouts and size estimates.

        Written out as straight-line code (one ``hl.Func`` per intermediate,
        in the original order) because this is a bug repro; the structure is
        deliberately left untouched.

        Stages, for each ``h0``:

        * ``tmp8``  - running maximum of ``in_ptr0[0..4, h0]`` -> ``out_ptr0``.
        * ``tmp23`` - log of the sum of ``exp(in_ptr0[i, h0] - tmp8)`` for
          ``i`` in 0..4 -> ``out_ptr1``.
        * ``tmp28`` - number of ``in_ptr1`` entries different from -100,
          reduced over ``rdom`` -> ``out_ptr4`` (as float32).
        * ``tmp43`` - sum over ``rdom`` of
          ``-((in_ptr0[idx, h0] - tmp8) - tmp23)`` where ``in_ptr1 != -100``
          (0.0 elsewhere); divided by the count -> ``out_ptr5``.
        """
        in_ptr0 = g.in_ptr0
        in_ptr1 = g.in_ptr1
        out_ptr0 = g.out_ptr0
        out_ptr1 = g.out_ptr1
        out_ptr4 = g.out_ptr4
        out_ptr5 = g.out_ptr5

        h0 = hl.Var("h0")
        # Reduction domain of extent 5 starting at 0; hr1 is its only variable.
        rdom = hl.RDom([hl.Range(0, 5)])
        hr1 = rdom[0]

        # --- Running maximum over in_ptr0[0..4, h0] --------------------------
        # The float/int choice is made in Python while the pipeline is being
        # built.  The float form selects the left operand when it is larger
        # or is NaN.
        tmp0 = hl.Func("tmp0")
        tmp0[h0] = in_ptr0[0, h0]
        tmp1 = hl.Func("tmp1")
        tmp1[h0] = in_ptr0[1, h0]
        tmp2 = hl.Func("tmp2")
        tmp2[h0] = (
            hl.select((tmp0[h0] > tmp1[h0]) | hl.is_nan(tmp0[h0]), tmp0[h0], tmp1[h0])
            if tmp0.type().is_float()
            else hl.max(tmp0[h0], tmp1[h0])
        )
        tmp3 = hl.Func("tmp3")
        tmp3[h0] = in_ptr0[2, h0]
        tmp4 = hl.Func("tmp4")
        tmp4[h0] = (
            hl.select((tmp2[h0] > tmp3[h0]) | hl.is_nan(tmp2[h0]), tmp2[h0], tmp3[h0])
            if tmp2.type().is_float()
            else hl.max(tmp2[h0], tmp3[h0])
        )
        tmp5 = hl.Func("tmp5")
        tmp5[h0] = in_ptr0[3, h0]
        tmp6 = hl.Func("tmp6")
        tmp6[h0] = (
            hl.select((tmp4[h0] > tmp5[h0]) | hl.is_nan(tmp4[h0]), tmp4[h0], tmp5[h0])
            if tmp4.type().is_float()
            else hl.max(tmp4[h0], tmp5[h0])
        )
        tmp7 = hl.Func("tmp7")
        tmp7[h0] = in_ptr0[4, h0]
        tmp8 = hl.Func("tmp8")
        tmp8[h0] = (
            hl.select((tmp6[h0] > tmp7[h0]) | hl.is_nan(tmp6[h0]), tmp6[h0], tmp7[h0])
            if tmp6.type().is_float()
            else hl.max(tmp6[h0], tmp7[h0])
        )
        out_ptr0[h0,] = hl.cast(hl.Float(32), tmp8[h0])

        # --- log(sum(exp(value - max))) --------------------------------------
        # hl.fast_exp is used when the operand is at most 32 bits wide,
        # hl.exp otherwise (again decided in Python at build time).
        tmp9 = hl.Func("tmp9")
        tmp9[h0] = tmp0[h0] - tmp8[h0]
        tmp10 = hl.Func("tmp10")
        tmp10[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp9[h0]))
            if tmp9.type().bits() <= 32
            else hl.exp(tmp9[h0])
        )
        tmp11 = hl.Func("tmp11")
        tmp11[h0] = tmp1[h0] - tmp8[h0]
        tmp12 = hl.Func("tmp12")
        tmp12[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp11[h0]))
            if tmp11.type().bits() <= 32
            else hl.exp(tmp11[h0])
        )
        tmp13 = hl.Func("tmp13")
        tmp13[h0] = tmp10[h0] + tmp12[h0]
        tmp14 = hl.Func("tmp14")
        tmp14[h0] = tmp3[h0] - tmp8[h0]
        tmp15 = hl.Func("tmp15")
        tmp15[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp14[h0]))
            if tmp14.type().bits() <= 32
            else hl.exp(tmp14[h0])
        )
        tmp16 = hl.Func("tmp16")
        tmp16[h0] = tmp13[h0] + tmp15[h0]
        tmp17 = hl.Func("tmp17")
        tmp17[h0] = tmp5[h0] - tmp8[h0]
        tmp18 = hl.Func("tmp18")
        tmp18[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp17[h0]))
            if tmp17.type().bits() <= 32
            else hl.exp(tmp17[h0])
        )
        tmp19 = hl.Func("tmp19")
        tmp19[h0] = tmp16[h0] + tmp18[h0]
        tmp20 = hl.Func("tmp20")
        tmp20[h0] = tmp7[h0] - tmp8[h0]
        tmp21 = hl.Func("tmp21")
        tmp21[h0] = (
            hl.fast_exp(hl.cast(hl.Float(32), tmp20[h0]))
            if tmp20.type().bits() <= 32
            else hl.exp(tmp20[h0])
        )
        tmp22 = hl.Func("tmp22")
        tmp22[h0] = tmp19[h0] + tmp21[h0]
        tmp23 = hl.Func("tmp23")
        tmp23[h0] = hl.log(tmp22[h0])
        out_ptr1[h0,] = hl.cast(hl.Float(32), tmp23[h0])

        # --- Mask and count of in_ptr1 entries that are not -100 -------------
        # Several names below are first bound to an hl.Func and then
        # immediately rebound to a scalar expression; that redundancy is kept
        # on purpose so the repro stays exactly as reported.
        tmp24 = hl.Func("tmp24")
        tmp24[h0] = in_ptr1[h0,]
        tmp25 = hl.Func("tmp25")
        tmp25 = hl.cast(hl.Int(64), -100)
        tmp26 = hl.Func("tmp26")
        tmp26[h0] = tmp24[h0] != tmp25
        tmp27 = hl.Func("tmp27")
        tmp27[h0] = hl.cast(hl.Int(64), tmp26[h0])
        tmp28 = hl.Func("tmp28")
        tmp28 = hl.sum(rdom, tmp27[hr1])

        # --- Index derived from in_ptr1: masked, wrapped, clamped ------------
        tmp29 = hl.Func("tmp29")
        tmp29 = hl.cast(hl.Int(64), 0)
        tmp30 = hl.Func("tmp30")
        tmp30[h0] = hl.select(tmp26[h0], tmp24[h0], hl.cast(tmp24.type(), tmp29))
        tmp31 = hl.Func("tmp31")
        tmp31 = 5
        tmp32 = hl.Func("tmp32")
        tmp32[h0] = tmp30[h0] + tmp31
        tmp33 = hl.Func("tmp33")
        tmp33[h0] = tmp30[h0] < 0
        # Negative indices are shifted up by 5 before the clamp to [0, 4].
        tmp34 = hl.Func("tmp34")
        tmp34[h0] = hl.select(tmp33[h0], tmp32[h0], hl.cast(tmp32.type(), tmp30[h0]))
        tmp35 = hl.Func("tmp35")
        tmp35[h0] = hl.cast(hl.Int(32), tmp34[h0])
        tmp36 = hl.Func("tmp36")
        tmp36[h0] = hl.clamp(tmp35[h0], 0, 4)
        # Data-dependent read of in_ptr0 along dimension 0.
        tmp37 = hl.Func("tmp37")
        tmp37[h0] = in_ptr0[tmp36[h0], h0]

        # --- Masked sum of -((value - max) - log-sum), divided by the count --
        tmp38 = hl.Func("tmp38")
        tmp38[h0] = tmp37[h0] - tmp8[h0]
        tmp39 = hl.Func("tmp39")
        tmp39[h0] = tmp38[h0] - tmp23[h0]
        tmp40 = hl.Func("tmp40")
        tmp40[h0] = -tmp39[h0]
        tmp41 = hl.Func("tmp41")
        tmp41 = hl.cast(hl.Float(32), hl.f64(0.0))
        tmp42 = hl.Func("tmp42")
        tmp42[h0] = hl.select(tmp26[h0], tmp40[h0], hl.cast(tmp40.type(), tmp41))
        tmp43 = hl.Func("tmp43")
        tmp43 = hl.sum(rdom, tmp42[hr1])
        tmp44 = hl.Func("tmp44")
        tmp44 = hl.cast(hl.Float(32), tmp28)
        # The two reduced results are written over a fresh anonymous Var.
        out_ptr4[hl.Var(),] = hl.cast(hl.Float(32), tmp44)
        tmp45 = hl.Func("tmp45")
        tmp45 = tmp43 / tmp44
        out_ptr5[hl.Var(),] = hl.cast(hl.Float(32), tmp45)

        # --- Scheduling: none by hand; only layout constraints and estimates -
        assert g.using_autoscheduler()
        # in_ptr0: extent 5 in both dimensions, strides 1 and 5.
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(5)
        in_ptr0.dim(1).set_min(0)
        in_ptr0.dim(1).set_stride(5)
        in_ptr0.dim(1).set_extent(5)
        in_ptr0.set_estimates([hl.Range(0, 5), hl.Range(0, 5)])
        # in_ptr1: extent 5, stride 1.
        in_ptr1.dim(0).set_min(0)
        in_ptr1.dim(0).set_stride(1)
        in_ptr1.dim(0).set_extent(5)
        in_ptr1.set_estimates([hl.Range(0, 5)])
        out_ptr0.set_estimates([hl.Range(0, 5)])
        out_ptr1.set_estimates([hl.Range(0, 5)])
        # NOTE: per the accompanying issue text, the estimate of 2 (the real
        # size is 1) is a deliberate workaround for issue #8246 - do not
        # "correct" it.
        out_ptr4.set_estimates([hl.Range(0, 2)])
        out_ptr5.set_estimates([hl.Range(0, 2)])
if __name__ == "__main__":
    # Script mode: hand the generator options to hl.main() through a
    # replaced sys.argv (presumably hl.main() parses sys.argv - confirm
    # against the Halide Python docs).  Artifacts go to a temporary directory
    # that is deleted when the `with` block exits.
    # NOTE: the plugin path is machine-specific and must be adjusted locally.
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            "repro.py",
            "-g", "kernel",
            "-o", out,
            "-f", "halide_kernel",
            "-e", "static_library,h,schedule",
            "-p", "/home/jansel/conda/envs/pytorch/lib/libautoschedule_anderson2021.so",
            "target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts",
            "autoscheduler=Anderson2021",
            "autoscheduler.parallelism=82",
        ]
        hl.main()
else:
    # Import mode: load the autoscheduler plugin, build the pipeline, apply
    # the autoscheduler explicitly, then compile to an in-process callable.
    # NOTE: the plugin path is machine-specific and must be adjusted locally.
    hl.load_plugin(
        "/home/jansel/conda/envs/pytorch/lib/libautoschedule_anderson2021.so"
    )
    target = hl.Target(
        "host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts"
    )
    autoscheduler = hl.AutoschedulerParams("Anderson2021", {"parallelism": 82})
    with hl.GeneratorContext(target, autoscheduler):
        gen = Kernel()
        pipeline = gen._build_pipeline()
        # gen.compile_to_callable() does not run the autoscheduler
        pipeline.apply_autoscheduler(target, autoscheduler)
        # Only the generator's Input arguments are passed as callable
        # arguments.
        kernel = pipeline.compile_to_callable(
            [
                gen._get_input_parameter(a.name)._to_argument()
                for a in gen._get_arginfos()
                if a.dir == hl.ArgInfoDirection.Input
            ],
            target,
        )
This code is a cleaned-up lowering of part of `torch.argmax(torch.adaptive_avg_pool1d(...))`.
repro.py
(You will need to update the path to `libautoschedule_anderson2021.so`.)
Output:
The code includes a workaround for #8246: it declares the output size as 2 (when it is actually 1). If I remove that workaround, I get the same error as in #8246. I think the workaround is uncovering a different issue, but the two issues are possibly related.