halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

LLVM ERROR: Transformation resulted in an invalid module (bfloat16+cuda) #8311

Open · jansel opened this issue 1 week ago

jansel commented 1 week ago

Repro:

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.BFloat(16), 2)
    out_ptr1 = hl.OutputBuffer(hl.BFloat(16), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        out_ptr1 = g.out_ptr1
        h0 = hl.Var("h0")
        h1 = hl.Var("h1")
        rdom = hl.RDom([hl.Range(0, 49)])
        hr1 = rdom[0]
        tmp0 = hl.Func("tmp0")
        tmp0[h0, h1] = hl.cast(
            hl.Float(32),
            in_ptr0[
                h0,
                h1,
            ],
        )
        tmp2 = hl.Func("tmp2")
        tmp2[h1] = hl.sum(rdom, tmp0[hr1, h1])
        tmp4 = hl.Func("tmp4")
        tmp4[h1] = tmp2[h1] / hl.f32(49.0)
        out_ptr1[h1,] = hl.cast(hl.BFloat(16), tmp4[h1])

        assert g.using_autoscheduler()
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(49)
        in_ptr0.dim(1).set_min(0)
        in_ptr0.dim(1).set_stride(49)
        in_ptr0.dim(1).set_extent(12)
        in_ptr0.set_estimates([hl.Range(0, 49), hl.Range(0, 12)])
        out_ptr1.set_estimates([hl.Range(0, 12)])

if __name__ == "__main__":
    import sys, tempfile

    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            "repro.py",
            "-g",
            "kernel",  # generator name (see @hl.generator above)
            "-o",
            out,  # output directory
            "-f",
            "halide_kernel",  # name of the emitted function
            "-e",
            "static_library,h,schedule",  # outputs to emit
            "-p",
            "/home/jansel/conda/envs/pytorch/lib/libautoschedule_li2018.so",  # autoscheduler plugin
            "target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts",
            "autoscheduler=Li2018",
            "autoscheduler.parallelism=82",
        ]
        hl.main()

Output:

$ python repro.py
FPExt only operates on FP
  %t99 = fpext i16 %7 to float
FPExt only operates on FP
  %t993 = fpext i16 %24 to float
FPTrunc only produces an FP
  %6 = fptrunc float %5 to i16
FPTrunc only produces an FP
  %14 = fptrunc float %13 to i16
LLVM ERROR: Transformation resulted in an invalid module

zsh: IOT instruction (core dumped)  python repro.py
jansel commented 1 week ago

Another example, starting from this PyTorch function:

import torch

def fn(a, b):
    x = a + b
    x_view = x.view(dtype=torch.int16)
    return x_view.mul(2)

which generates:

import halide as hl
from math import inf, nan

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.BFloat(16), 1)
    in_ptr1 = hl.InputBuffer(hl.BFloat(16), 1)
    out_ptr0 = hl.OutputBuffer(hl.Int(16), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        in_ptr1 = g.in_ptr1
        out_ptr0 = g.out_ptr0
        h0 = hl.Var("h0")
        tmp0 = hl.Func("tmp0")
        tmp0[h0] = hl.cast(hl.Float(32), in_ptr0[h0,])
        tmp1 = hl.Func("tmp1")
        tmp1[h0] = hl.cast(hl.Float(32), in_ptr1[h0,])
        tmp2 = hl.Func("tmp2")
        tmp2[h0] = tmp0[h0] + tmp1[h0]
        tmp3 = hl.Func("tmp3")
        tmp3[h0] = hl.reinterpret(hl.Int(16), hl.cast(hl.BFloat(16), tmp2[h0]))
        tmp4 = hl.Func("tmp4")
        tmp4 = hl.cast(hl.Int(16), 2)
        tmp5 = hl.Func("tmp5")
        tmp5[h0] = tmp3[h0] * tmp4
        out_ptr0[h0,] = hl.cast(hl.Int(16), tmp5[h0])

        assert g.using_autoscheduler()
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(4)
        in_ptr0.set_estimates([hl.Range(0, 4)])
        in_ptr1.dim(0).set_min(0)
        in_ptr1.dim(0).set_stride(1)
        in_ptr1.dim(0).set_extent(4)
        in_ptr1.set_estimates([hl.Range(0, 4)])
        out_ptr0.set_estimates([hl.Range(0, 4)])

if __name__ == "__main__":
    import sys, tempfile

    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            "repro.py",
            "-g",
            "kernel",
            "-o",
            out,
            "-f",
            "halide_kernel",
            "-e",
            "static_library,h,schedule",
            "-p",
            "/home/jansel/conda/envs/pytorch/lib/libautoschedule_anderson2021.so",
            "target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts",
            "autoscheduler=Anderson2021",
            "autoscheduler.parallelism=82",
        ]
        hl.main()

This works on CPU, but on CUDA it fails with:

FPExt only operates on FP
  %6 = fpext i16 %5 to float
FPExt only operates on FP
  %9 = fpext i16 %8 to float
FPTrunc only produces an FP
  %11 = fptrunc float %10 to i16
FPExt only operates on FP
  %20 = fpext i16 %19 to float
FPExt only operates on FP
  %24 = fpext i16 %23 to float
FPTrunc only produces an FP
  %26 = fptrunc float %25 to i16
LLVM ERROR: Transformation resulted in an invalid module

zsh: IOT instruction (core dumped)  python repro.py
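For context: the generator above appears to be what PyTorch's Inductor emits when its experimental Halide backend compiles fn, but the issue does not spell that path out, so the end-to-end sketch below is an assumption. In particular, the inductor_config.cuda_backend knob is assumed here and may be named differently (or absent) in other PyTorch versions.

import torch
import torch._inductor.config as inductor_config

# Assumption: this knob selects the experimental Halide backend in recent
# PyTorch builds; it is not part of this issue report.
inductor_config.cuda_backend = "halide"

def fn(a, b):  # the function shown above
    x = a + b
    x_view = x.view(dtype=torch.int16)
    return x_view.mul(2)

compiled = torch.compile(fn)
a = torch.randn(4, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, device="cuda", dtype=torch.bfloat16)
out = compiled(a, b)  # Inductor builds a Halide kernel similar to the one above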
abadams commented 1 week ago

I think this is because we're using i16s to represent bfloat16s and something in codegen isn't handling it right. LLVM has bfloat as a native type now, so we can probably just switch to that.
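For reference, a bfloat16 value is just the top 16 bits of an IEEE float32 bit pattern, so the widening and narrowing casts can be expressed with integer moves and shifts alone. The sketch below shows that bit-level view using the Python frontend; it is an illustration of what the conversion has to do, not Halide's actual lowering (and the narrowing path here truncates rather than rounding to nearest-even).

import halide as hl

def bf16_to_f32(x):
    # Take the 16 raw bfloat16 bits, widen to 32 bits, and shift them into
    # the top half of a float32 bit pattern.
    bits = hl.cast(hl.UInt(32), hl.reinterpret(hl.UInt(16), x))
    return hl.reinterpret(hl.Float(32), bits << 16)

def f32_to_bf16(x):
    # Keep only the top 16 bits of the float32 bit pattern (truncating).
    bits = hl.cast(hl.UInt(16), hl.reinterpret(hl.UInt(32), x) >> 16)
    return hl.reinterpret(hl.BFloat(16), bits)

Switching to LLVM's native bfloat type, as suggested above, would instead let LLVM choose the right lowering for each target.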

abadams commented 1 week ago

#8324