Another example from:
```python
def fn(a, b):
    x = a + b
    x_view = x.view(dtype=torch.int16)
    return x_view.mul(2)
```
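For reference, here is roughly what the function computes in eager mode (a minimal sketch; the length-4 bfloat16 inputs are an assumption, chosen to match the buffer estimates in the generated kernel below):

```python
import torch

a = torch.randn(4, dtype=torch.bfloat16)
b = torch.randn(4, dtype=torch.bfloat16)
# The sum keeps dtype bfloat16; view(dtype=torch.int16) bitcasts the
# 16-bit payload, and mul(2) then operates on ordinary int16 values.
out = (a + b).view(dtype=torch.int16).mul(2)
assert out.dtype == torch.int16
```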
Compiling this function generates:
```python
import halide as hl
from math import inf, nan


@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.BFloat(16), 1)
    in_ptr1 = hl.InputBuffer(hl.BFloat(16), 1)
    out_ptr0 = hl.OutputBuffer(hl.Int(16), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        in_ptr1 = g.in_ptr1
        out_ptr0 = g.out_ptr0
        h0 = hl.Var("h0")
        tmp0 = hl.Func("tmp0")
        tmp0[h0] = hl.cast(hl.Float(32), in_ptr0[h0,])
        tmp1 = hl.Func("tmp1")
        tmp1[h0] = hl.cast(hl.Float(32), in_ptr1[h0,])
        tmp2 = hl.Func("tmp2")
        tmp2[h0] = tmp0[h0] + tmp1[h0]
        tmp3 = hl.Func("tmp3")
        tmp3[h0] = hl.reinterpret(hl.Int(16), hl.cast(hl.BFloat(16), tmp2[h0]))
        tmp4 = hl.Func("tmp4")
        tmp4 = hl.cast(hl.Int(16), 2)
        tmp5 = hl.Func("tmp5")
        tmp5[h0] = tmp3[h0] * tmp4
        out_ptr0[h0,] = hl.cast(hl.Int(16), tmp5[h0])

        assert g.using_autoscheduler()
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(4)
        in_ptr0.set_estimates([hl.Range(0, 4)])
        in_ptr1.dim(0).set_min(0)
        in_ptr1.dim(0).set_stride(1)
        in_ptr1.dim(0).set_extent(4)
        in_ptr1.set_estimates([hl.Range(0, 4)])
        out_ptr0.set_estimates([hl.Range(0, 4)])


if __name__ == "__main__":
    import sys, tempfile

    with tempfile.TemporaryDirectory() as out:
        sys.argv = [
            "repro.py",
            "-g",
            "kernel",
            "-o",
            out,
            "-f",
            "halide_kernel",
            "-e",
            "static_library,h,schedule",
            "-p",
            "/home/jansel/conda/envs/pytorch/lib/libautoschedule_anderson2021.so",
            "target=host-cuda-cuda_capability_86-user_context-strict_float-no_runtime-no_asserts",
            "autoscheduler=Anderson2021",
            "autoscheduler.parallelism=82",
        ]
        hl.main()
```
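To double-check the CPU side, the same generator can be driven with a host-only target. This is a sketch; the Adams2019 autoscheduler and its library path are assumptions, mirroring the CUDA invocation above:

```python
import sys, tempfile
import halide as hl  # reuses the Kernel generator defined above

with tempfile.TemporaryDirectory() as out:
    sys.argv = [
        "repro.py",
        "-g", "kernel",
        "-o", out,
        "-f", "halide_kernel",
        "-e", "static_library,h,schedule",
        # Assumed path: the CPU autoscheduler installed alongside the CUDA one.
        "-p", "/home/jansel/conda/envs/pytorch/lib/libautoschedule_adams2019.so",
        # Host-only target: drop the CUDA features so only the CPU path runs.
        "target=host-strict_float-no_runtime-no_asserts",
        "autoscheduler=Adams2019",
    ]
    hl.main()
```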
This works on CPU, but on CUDA it fails with:
```
FPExt only operates on FP
  %6 = fpext i16 %5 to float
FPExt only operates on FP
  %9 = fpext i16 %8 to float
FPTrunc only produces an FP
  %11 = fptrunc float %10 to i16
FPExt only operates on FP
  %20 = fpext i16 %19 to float
FPExt only operates on FP
  %24 = fpext i16 %23 to float
FPTrunc only produces an FP
  %26 = fptrunc float %25 to i16
LLVM ERROR: Transformation resulted in an invalid module
zsh: IOT instruction (core dumped)  python repro.py
```
I think this is because we're using i16 to represent bfloat16, and the CUDA codegen then emits fpext/fptrunc on those integer values, which LLVM's verifier rejects. LLVM has bfloat as a native type now, so we can probably just switch to that.
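For concreteness, widening the i16-encoded bfloat16 is a bit shift into the top half of an fp32 pattern, not an fpext. A plain-Python sketch of the two conversions (the narrowing side truncates; rounding is ignored for brevity):

```python
import struct

def bf16_bits_to_fp32(bits: int) -> float:
    # bfloat16 is the top 16 bits of an fp32; widening shifts the
    # payload into the high half of the 32-bit pattern.
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

def fp32_to_bf16_bits(x: float) -> int:
    # Truncating narrow: keep only the top 16 bits of the fp32 pattern.
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16

print(bf16_bits_to_fp32(0x3F80))    # 1.0
print(hex(fp32_to_bf16_bits(1.0)))  # 0x3f80
```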