intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

[PyTorch Upstream] Triton crash when compile torch.nn.functional.interpolate.bicubic #829

Closed: etaf closed this issue 7 months ago

etaf commented 7 months ago

We got an error in an Inductor UT when using Triton to compile torch.nn.functional.interpolate with mode='bicubic':

L0 build module failed. Log:
Fatal Python error: Segmentation fault

To reproduce this error (the Triton version pinned in public PyTorch is b8c64f64c18d8cac598b3adb355c21e7439c21de):

  1. Clone PyTorch from https://github.com/pytorch/pytorch.git
  2. Build with export USE_XPU=1
  3. Before running the case, export PYTORCH_ENABLE_XPU_FALLBACK=1
  4. Run the following case:
    
    import torch

    def fn(a):
        return torch.nn.functional.interpolate(a, size=(3, 3), mode='bicubic')

    myfn = torch.compile(fn, backend="inductor")
    a = torch.randn(2, 3, 4, 4, dtype=torch.float32, device="xpu", requires_grad=True)
    print(myfn(a))


The Triton kernels Inductor generated are:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()

# kernel path: /tmp/torchinductor_xinanlin/5b/c5b3dekb2ejuofxx7nrwrx6wusxc5gcrzecpz6usndcthpn4u37m.py
# Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.clamp, aten.floor, aten.sub]
# interpolate => clamp_max_2, clamp_min_2, convert_element_type_3, floor_1, sub_4

triton_poi_fused__to_copy_clamp_floor_sub_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[4],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_clamp_floor_sub_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '35c35aad0fdd2b464ff093e3eeb1b3edcf6bfa6280af00b4abe43e27387608db'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = 1.3333333333333333
    tmp5 = tmp3 * tmp4
    tmp6 = tmp5 - tmp2
    tmp7 = libdevice.floor(tmp6)
    tmp8 = tmp7.to(tl.int32)
    tmp9 = tl.full([1], 1, tl.int64)
    tmp10 = tmp8 - tmp9
    tmp11 = tl.full([1], 0, tl.int64)
    tmp12 = triton_helpers.maximum(tmp10, tmp11)
    tmp13 = tl.full([1], 3, tl.int64)
    tmp14 = triton_helpers.minimum(tmp12, tmp13)
    tl.store(out_ptr0 + (x0), tmp14, xmask)
''', device_str='xpu')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _xpu_getCurrentRawStream as get_raw_stream

# kernel path: /tmp/torchinductor_xinanlin/j7/cj7spkiwbrrkrvcqrsciz5527gq3ipqtgsq3nrplcjpi7tn7tee4.py
# Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.sub]
# interpolate => add, clamp_max_5, clamp_min_5, convert_element_type, convert_element_type_2, floor, iota, mul, sub

triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[4],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_1', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '35c35aad0fdd2b464ff093e3eeb1b3edcf6bfa6280af00b4abe43e27387608db'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = 1.3333333333333333
    tmp5 = tmp3 * tmp4
    tmp6 = tmp5 - tmp2
    tmp7 = libdevice.floor(tmp6)
    tmp8 = tmp7.to(tl.int32)
    tmp9 = tl.full([1], 0, tl.int64)
    tmp10 = triton_helpers.maximum(tmp8, tmp9)
    tmp11 = tl.full([1], 3, tl.int64)
    tmp12 = triton_helpers.minimum(tmp10, tmp11)
    tl.store(out_ptr0 + (x0), tmp12, xmask)
''', device_str='xpu')

# kernel path: /tmp/torchinductor_xinanlin/ty/ctyl4rtlhgoyusglk3bgx23ru3l5s4zqxyfuv3htlgxsymelj7ny.py
# Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.sub]
# interpolate => add, add_4, clamp_max_7, clamp_min_7, convert_element_type, convert_element_type_2, floor, iota, mul, sub

triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[4],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_2', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '35c35aad0fdd2b464ff093e3eeb1b3edcf6bfa6280af00b4abe43e27387608db'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = 1.3333333333333333
    tmp5 = tmp3 * tmp4
    tmp6 = tmp5 - tmp2
    tmp7 = libdevice.floor(tmp6)
    tmp8 = tmp7.to(tl.int32)
    tmp9 = tl.full([1], 1, tl.int64)
    tmp10 = tmp8 + tmp9
    tmp11 = tl.full([1], 0, tl.int64)
    tmp12 = triton_helpers.maximum(tmp10, tmp11)
    tmp13 = tl.full([1], 3, tl.int64)
    tmp14 = triton_helpers.minimum(tmp12, tmp13)
    tl.store(out_ptr0 + (x0), tmp14, xmask)
''', device_str='xpu')

# kernel path: /tmp/torchinductor_xinanlin/xz/cxzitmyy4mlangsorvkmhwt55qgrhjvc4jl7jna2boprbwe3tuik.py
# Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.sub]
# interpolate => add, add_5, clamp_max_9, clamp_min_9, convert_element_type, convert_element_type_2, floor, iota, mul, sub

triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[4],
    filename=__file__,
    triton_meta={'signature': {0: '*i64', 1: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_3', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '35c35aad0fdd2b464ff093e3eeb1b3edcf6bfa6280af00b4abe43e27387608db'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = x0
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.5
    tmp3 = tmp1 + tmp2
    tmp4 = 1.3333333333333333
    tmp5 = tmp3 * tmp4
    tmp6 = tmp5 - tmp2
    tmp7 = libdevice.floor(tmp6)
    tmp8 = tmp7.to(tl.int32)
    tmp9 = tl.full([1], 2, tl.int64)
    tmp10 = tmp8 + tmp9
    tmp11 = tl.full([1], 0, tl.int64)
    tmp12 = triton_helpers.maximum(tmp10, tmp11)
    tmp13 = tl.full([1], 3, tl.int64)
    tmp14 = triton_helpers.minimum(tmp12, tmp13)
    tl.store(out_ptr0 + (x0), tmp14, xmask)
''', device_str='xpu')

# kernel path: /tmp/torchinductor_xinanlin/3i/c3iyyt5rbl5245hmuwfaygtjvbankow7xzbrtsuijllwq4h5jxcy.py
# Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten._unsafe_index, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.rsub, aten.sub]
# interpolate => _unsafe_index, _unsafe_index_1, _unsafe_index_10, _unsafe_index_11, _unsafe_index_12, _unsafe_index_13, _unsafe_index_14, _unsafe_index_15, _unsafe_index_2, _unsafe_index_3, _unsafe_index_4, _unsafe_index_5, _unsafe_index_6, _unsafe_index_7, _unsafe_index_8, _unsafe_index_9, add, add_10, add_11, add_12, add_13, add_14, add_15, add_16, add_17, add_18, add_19, add_20, add_21, add_22, add_23, add_24, add_25, add_26, add_27, add_28, add_29, add_30, add_6, add_7, add_8, add_9, clamp_max, clamp_max_1, clamp_min, clamp_min_1, convert_element_type, floor, floor_1, iota, mul, mul_10, mul_11, mul_12, mul_13, mul_14, mul_15, mul_16, mul_17, mul_18, mul_19, mul_2, mul_20, mul_21, mul_22, mul_23, mul_24, mul_25, mul_26, mul_27, mul_28, mul_29, mul_3, mul_30, mul_31, mul_32, mul_33, mul_34, mul_35, mul_36, mul_37, mul_38, mul_39, mul_4, mul_40, mul_41, mul_42, mul_43, mul_44, mul_45, mul_5, mul_6, mul_7, mul_8, mul_9, sub, sub_10, sub_11, sub_12, sub_13, sub_14, sub_15, sub_16, sub_17, sub_18, sub_19, sub_2, sub_20, sub_21, sub_3, sub_6, sub_7, sub_8, sub_9

triton_poi_fused__to_copy__unsafe_index_add_arange_clamp_floor_mul_rsub_sub_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[64],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*i64', 2: '*i64', 3: '*fp32', 4: '*i64', 5: '*i64', 6: '*i64', 7: '*i64', 8: '*i64', 9: '*i64', 10: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy__unsafe_index_add_arange_clamp_floor_mul_rsub_sub_4', 'mutated_arg_names': ['in_out_ptr2'], 'no_x_dim': False, 'backend_hash': '35c35aad0fdd2b464ff093e3eeb1b3edcf6bfa6280af00b4abe43e27387608db'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr2, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, xnumel, XBLOCK : tl.constexpr):
    xnumel = 54
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 3) % 3
    x0 = xindex % 3
    x2 = (xindex // 9)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x1), xmask, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last')
    tmp34 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last')
    tmp48 = tl.load(in_ptr4 + (x1), xmask, eviction_policy='evict_last')
    tmp57 = tl.load(in_ptr5 + (x1), xmask, eviction_policy='evict_last')
    tmp66 = tl.load(in_ptr6 + (x1), xmask, eviction_policy='evict_last')
    tmp75 = tl.load(in_ptr7 + (x0), xmask, eviction_policy='evict_last')
    tmp88 = tl.load(in_ptr8 + (x0), xmask, eviction_policy='evict_last')
    tmp1 = tmp0 + 4
    tmp2 = tmp0 < 0
    tmp3 = tl.where(tmp2, tmp1, tmp0)
    tmp5 = tmp4 + 4
    tmp6 = tmp4 < 0
    tmp7 = tl.where(tmp6, tmp5, tmp4)
    tmp8 = tl.load(in_ptr2 + (tmp7 + (4*tmp3) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp9 = x0
    tmp10 = tmp9.to(tl.float32)
    tmp11 = 0.5
    tmp12 = tmp10 + tmp11
    tmp13 = 1.3333333333333333
    tmp14 = tmp12 * tmp13
    tmp15 = tmp14 - tmp11
    tmp16 = libdevice.floor(tmp15)
    tmp17 = tmp15 - tmp16
    tmp18 = 0.0
    tmp19 = triton_helpers.maximum(tmp17, tmp18)
    tmp20 = 1.0
    tmp21 = triton_helpers.minimum(tmp19, tmp20)
    tmp22 = tmp21 + tmp20
    tmp23 = -0.75
    tmp24 = tmp22 * tmp23
    tmp25 = -3.75
    tmp26 = tmp24 - tmp25
    tmp27 = tmp26 * tmp22
    tmp28 = -6.0
    tmp29 = tmp27 + tmp28
    tmp30 = tmp29 * tmp22
    tmp31 = -3.0
    tmp32 = tmp30 - tmp31
    tmp33 = tmp8 * tmp32
    tmp35 = tmp34 + 4
    tmp36 = tmp34 < 0
    tmp37 = tl.where(tmp36, tmp35, tmp34)
    tmp38 = tl.load(in_ptr2 + (tmp37 + (4*tmp3) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp39 = 1.25
    tmp40 = tmp21 * tmp39
    tmp41 = 2.25
    tmp42 = tmp40 - tmp41
    tmp43 = tmp42 * tmp21
    tmp44 = tmp43 * tmp21
    tmp45 = tmp44 + tmp20
    tmp46 = tmp38 * tmp45
    tmp47 = tmp33 + tmp46
    tmp49 = tmp48 + 4
    tmp50 = tmp48 < 0
    tmp51 = tl.where(tmp50, tmp49, tmp48)
    tmp52 = tl.load(in_ptr2 + (tmp7 + (4*tmp51) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp53 = tmp52 * tmp32
    tmp54 = tl.load(in_ptr2 + (tmp37 + (4*tmp51) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp55 = tmp54 * tmp45
    tmp56 = tmp53 + tmp55
    tmp58 = tmp57 + 4
    tmp59 = tmp57 < 0
    tmp60 = tl.where(tmp59, tmp58, tmp57)
    tmp61 = tl.load(in_ptr2 + (tmp7 + (4*tmp60) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp62 = tmp61 * tmp32
    tmp63 = tl.load(in_ptr2 + (tmp37 + (4*tmp60) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp64 = tmp63 * tmp45
    tmp65 = tmp62 + tmp64
    tmp67 = tmp66 + 4
    tmp68 = tmp66 < 0
    tmp69 = tl.where(tmp68, tmp67, tmp66)
    tmp70 = tl.load(in_ptr2 + (tmp7 + (4*tmp69) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp71 = tmp70 * tmp32
    tmp72 = tl.load(in_ptr2 + (tmp37 + (4*tmp69) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp73 = tmp72 * tmp45
    tmp74 = tmp71 + tmp73
    tmp76 = tmp75 + 4
    tmp77 = tmp75 < 0
    tmp78 = tl.where(tmp77, tmp76, tmp75)
    tmp79 = tl.load(in_ptr2 + (tmp78 + (4*tmp3) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp80 = tmp20 - tmp21
    tmp81 = tmp80 * tmp39
    tmp82 = tmp81 - tmp41
    tmp83 = tmp82 * tmp80
    tmp84 = tmp83 * tmp80
    tmp85 = tmp84 + tmp20
    tmp86 = tmp79 * tmp85
    tmp87 = tmp47 + tmp86
    tmp89 = tmp88 + 4
    tmp90 = tmp88 < 0
    tmp91 = tl.where(tmp90, tmp89, tmp88)
    tmp92 = tl.load(in_ptr2 + (tmp91 + (4*tmp3) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp93 = 2.0
    tmp94 = tmp93 - tmp21
    tmp95 = tmp94 * tmp23
    tmp96 = tmp95 - tmp25
    tmp97 = tmp96 * tmp94
    tmp98 = tmp97 + tmp28
    tmp99 = tmp98 * tmp94
    tmp100 = tmp99 - tmp31
    tmp101 = tmp92 * tmp100
    tmp102 = tmp87 + tmp101
    tmp103 = tl.load(in_ptr2 + (tmp78 + (4*tmp51) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp104 = tmp103 * tmp85
    tmp105 = tmp56 + tmp104
    tmp106 = tl.load(in_ptr2 + (tmp91 + (4*tmp51) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp107 = tmp106 * tmp100
    tmp108 = tmp105 + tmp107
    tmp109 = tl.load(in_ptr2 + (tmp78 + (4*tmp60) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp110 = tmp109 * tmp85
    tmp111 = tmp65 + tmp110
    tmp112 = tl.load(in_ptr2 + (tmp91 + (4*tmp60) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp113 = tmp112 * tmp100
    tmp114 = tmp111 + tmp113
    tmp115 = tl.load(in_ptr2 + (tmp78 + (4*tmp69) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp116 = tmp115 * tmp85
    tmp117 = tmp74 + tmp116
    tmp118 = tl.load(in_ptr2 + (tmp91 + (4*tmp69) + (16*x2)), xmask, eviction_policy='evict_last')
    tmp119 = tmp118 * tmp100
    tmp120 = tmp117 + tmp119
    tmp121 = x1
    tmp122 = tmp121.to(tl.float32)
    tmp123 = tmp122 + tmp11
    tmp124 = tmp123 * tmp13
    tmp125 = tmp124 - tmp11
    tmp126 = libdevice.floor(tmp125)
    tmp127 = tmp125 - tmp126
    tmp128 = triton_helpers.maximum(tmp127, tmp18)
    tmp129 = triton_helpers.minimum(tmp128, tmp20)
    tmp130 = tmp129 + tmp20
    tmp131 = tmp130 * tmp23
    tmp132 = tmp131 - tmp25
    tmp133 = tmp132 * tmp130
    tmp134 = tmp133 + tmp28
    tmp135 = tmp134 * tmp130
    tmp136 = tmp135 - tmp31
    tmp137 = tmp102 * tmp136
    tmp138 = tmp129 * tmp39
    tmp139 = tmp138 - tmp41
    tmp140 = tmp139 * tmp129
    tmp141 = tmp140 * tmp129
    tmp142 = tmp141 + tmp20
    tmp143 = tmp108 * tmp142
    tmp144 = tmp137 + tmp143
    tmp145 = tmp20 - tmp129
    tmp146 = tmp145 * tmp39
    tmp147 = tmp146 - tmp41
    tmp148 = tmp147 * tmp145
    tmp149 = tmp148 * tmp145
    tmp150 = tmp149 + tmp20
    tmp151 = tmp114 * tmp150
    tmp152 = tmp144 + tmp151
    tmp153 = tmp93 - tmp129
    tmp154 = tmp153 * tmp23
    tmp155 = tmp154 - tmp25
    tmp156 = tmp155 * tmp153
    tmp157 = tmp156 + tmp28
    tmp158 = tmp157 * tmp153
    tmp159 = tmp158 - tmp31
    tmp160 = tmp120 * tmp159
    tmp161 = tmp152 + tmp160
    tl.store(in_out_ptr2 + (x4), tmp161, xmask)
''', device_str='xpu')

async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, = args
    args.clear()
    assert_size_stride(primals_1, (2, 3, 4, 4), (48, 16, 4, 1))
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        buf0 = empty_strided((3, 1), (1, 1), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.clamp, aten.floor, aten.sub]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_clamp_floor_sub_0.run(buf0, 3, grid=grid(3), stream=stream0)
        buf1 = empty_strided((3, ), (1, ), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.sub]
        triton_poi_fused__to_copy_clamp_floor_sub_0.run(buf1, 3, grid=grid(3), stream=stream0)
        buf2 = empty_strided((3, ), (1, ), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.sub]
        triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_1.run(buf2, 3, grid=grid(3), stream=stream0)
        buf3 = empty_strided((3, ), (1, ), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.sub]
        triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_2.run(buf3, 3, grid=grid(3), stream=stream0)
        buf4 = empty_strided((3, ), (1, ), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.sub]
        triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_3.run(buf4, 3, grid=grid(3), stream=stream0)
        buf10 = empty_strided((3, 1), (1, 1), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.clamp, aten.floor]
        triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_2.run(buf10, 3, grid=grid(3), stream=stream0)
        buf13 = empty_strided((3, 1), (1, 1), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.add, aten.clamp, aten.floor]
        triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_3.run(buf13, 3, grid=grid(3), stream=stream0)
        buf7 = empty_strided((3, 1), (1, 1), device='xpu', dtype=torch.int64)
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten.clamp, aten.floor]
        triton_poi_fused__to_copy_add_arange_clamp_floor_mul_sub_1.run(buf7, 3, grid=grid(3), stream=stream0)
        buf11 = empty_strided((2, 3, 3, 3), (27, 9, 3, 1), device='xpu', dtype=torch.float32)
        buf12 = buf11; del buf11  # reuse
        buf17 = buf12; del buf12  # reuse
        # Source Nodes: [interpolate], Original ATen: [aten._to_copy, aten._unsafe_index, aten.add, aten.arange, aten.clamp, aten.floor, aten.mul, aten.rsub, aten.sub]
        triton_poi_fused__to_copy__unsafe_index_add_arange_clamp_floor_mul_rsub_sub_4.run(buf17, buf0, buf1, primals_1, buf2, buf7, buf10, buf13, buf3, buf4, 54, grid=grid(54), stream=stream0)
        del primals_1
    return (buf17, buf0, buf1, buf2, buf3, buf4, buf7, buf10, buf13, )

def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((2, 3, 4, 4), (48, 16, 4, 1), device='xpu:0', dtype=torch.float32)
    fn = lambda: call([primals_1])
    return print_performance(fn, times=times, repeat=repeat)

if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

ienkovich commented 7 months ago

Triton successfully compiles all kernels but fails to load one of them. I tried the corresponding SPIR-V file in a reproducer, and the zeModuleCreate call fails with ZE_RESULT_ERROR_MODULE_BUILD_FAILURE. The description of this error says to look at the build log for more details. I followed the documentation example to get the log but got an empty string.
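
For reference, a minimal sketch of how the build log can be requested from zeModuleCreate; the function name and the `ctx`/`dev`/`spv` parameters are illustrative, with the SPIR-V binary assumed to come from the attached reproducer:

    // Request the Level Zero build log alongside module creation.
    #include <level_zero/ze_api.h>
    #include <cstdio>
    #include <vector>

    void buildAndDumpLog(ze_context_handle_t ctx, ze_device_handle_t dev,
                         const uint8_t *spv, size_t spvSize) {
      ze_module_desc_t desc = {};
      desc.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
      desc.format = ZE_MODULE_FORMAT_IL_SPIRV;
      desc.pInputModule = spv;
      desc.inputSize = spvSize;

      ze_module_handle_t module = nullptr;
      ze_module_build_log_handle_t log = nullptr;
      ze_result_t res = zeModuleCreate(ctx, dev, &desc, &module, &log);
      if (res != ZE_RESULT_SUCCESS && log) {
        // Two-call query: first get the size, then the contents.
        size_t size = 0;
        zeModuleBuildLogGetString(log, &size, nullptr);
        std::vector<char> text(size + 1, '\0');
        zeModuleBuildLogGetString(log, &size, text.data());
        std::fprintf(stderr, "zeModuleCreate failed (0x%x); build log: %s\n",
                     static_cast<unsigned>(res),
                     size ? text.data() : "<empty>");
      }
      if (log)
        zeModuleBuildLogDestroy(log);
      if (module)
        zeModuleDestroy(module);
    }

An empty string from this query, as observed here, means the driver attached no diagnostics for the failure.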

ienkovich commented 7 months ago

spirv-reproducer-issue-829.tar.gz

ienkovich commented 7 months ago

One suspicious place in our generated LLVM IR is a conversion from float to int:

%17 = tail call i2 @llvm.fptosi.sat.i2.f32(float %16), !dbg !26

I'd expect i32 to be used here instead of i2. I'll check why this happens.

ienkovich commented 7 months ago

The i2 type usage is a result of LLVM optimizations. Here is the code before optimizations:

  %17 = fptosi float %16 to i32, !dbg !26
  %18 = call i32 @llvm.smax.i32(i32 %17, i32 -2), !dbg !27
  %19 = call i32 @llvm.smin.i32(i32 %18, i32 1), !dbg !28
  %narrow = add nsw i32 %19, 2, !dbg !28

And after AggressiveInstCombinePass:

  %17 = call i2 @llvm.fptosi.sat.i2.f32(float %16), !dbg !26
  %18 = sext i2 %17 to i32, !dbg !26
  %narrow = add nsw i32 %18, 2, !dbg !26

fptosi + llvm.smax.i32 + llvm.smin.i32 are combined into llvm.fptosi.sat.i2.f32 + sext. So we should file an IGC bug to get their feedback.

etiotto commented 7 months ago

I do not see anything incorrect about the transformation the LLVM optimizer did. Using i2 is legal because the original code clamps the value of %17 to [-2, 1], which fits into an i2 data type. The LLVM/SPIRV translator should handle this case IMO.
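
Concretely, the smax/smin pair in the pre-optimization IR pins the value to exactly the representable range of a signed 2-bit integer:

$$\operatorname{smin}(\operatorname{smax}(x,\,-2),\,1) \in [-2,\,1] = [-2^{2-1},\,2^{2-1}-1].$$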

There is a PR (https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/2500) to fix this. Once it lands, this test should work.

MrSidims commented 7 months ago

> I do not see anything incorrect about the transformation the LLVM optimizer did. Using i2 is legal because the original code clamps the value of %17 to [-2, 1], which fits into an i2 data type. The LLVM/SPIRV translator should handle this case IMO.
>
> There is a PR (KhronosGroup/SPIRV-LLVM-Translator#2500) to fix this. Once it lands, this test should work.

Does this mean that we should handle every integer type and every vector type? The fact that this particular patch makes this particular test pass doesn't mean that compilation on the user side won't fail due to some other optimization generating a type that is unsupported in SPIR-V.

asudarsa commented 7 months ago

I am OK with this one-off solution, but we should try to resolve this issue in a better way in the long run. One way I envision is for the backend compiler to natively support non-standard data types, for the target datalayout to be updated to reflect this, and then for the SPIR-V translator to support emitting non-standard data types. One caveat is that the OpenCL 3.0 spec does not support non-standard data types.

Thanks

ienkovich commented 7 months ago

In this case, our front-end simply doesn't have enough control over the fptosi.sat intrinsic usage. If some intrinsic calls are produced by LLVM optimizations, then it's better to have them supported by the translator; otherwise the whole pipeline becomes unreliable.

As for LLVM, it looks like a flaw in the TargetTransformInfo used by default for the target. I don't see how this transformation can be beneficial for non-native types; the cost model for the saturated conversion must be incorrect.

MrSidims commented 7 months ago

> One way I envision is for the backend compiler to natively support non-standard data types

What is the backend compiler? In SPIR-V (as in any JIT compilation model) we don't know the backend until runtime. SPIR-V (and the MLIR used in this very project!) is standardized precisely so that everyone agrees to support certain functionality; otherwise, why wouldn't we just emit LLVM IR and feed it to our devices?

> then it's better to have them supported by the translator

I agree that the translator (or whatever future producer of SPIR-V from LLVM IR) should handle any integer types and any vector sizes. I don't agree with handling it case by case like this. The reasonable question: wouldn't it be much cheaper to add the appropriate workaround in the pass? If you are using intel/llvm, it should be as simple as that.
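
For illustration, a rough sketch (entirely hypothetical, not code from this repo) of what such a workaround pass could look like if run on the LLVM IR before translation: it expands a non-byte-sized llvm.fptosi.sat back into the fptosi + smax + smin form the translator already handles. Note the NaN caveat in the comments.

    // Hypothetical fixup pass: expand llvm.fptosi.sat.iN for small
    // non-standard N back into fptosi + smax + smin on i32, so no illegal
    // integer type reaches the SPIR-V translator. Caveat: fptosi.sat maps
    // NaN to 0 while plain fptosi on NaN is poison, so this is only safe
    // when the input is known not to be NaN (true for the Inductor kernel
    // above, where the value comes from floor() of a bounded expression).
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/InstIterator.h"
    #include "llvm/IR/IntrinsicInst.h"
    #include "llvm/IR/PassManager.h"
    using namespace llvm;

    struct ExpandNonStdFPToSISat : PassInfoMixin<ExpandNonStdFPToSISat> {
      PreservedAnalyses run(Function &F, FunctionAnalysisManager &) {
        SmallVector<IntrinsicInst *, 8> Worklist;
        for (Instruction &I : instructions(F))
          if (auto *II = dyn_cast<IntrinsicInst>(&I))
            if (II->getIntrinsicID() == Intrinsic::fptosi_sat &&
                II->getType()->isIntegerTy()) {
              unsigned BW = II->getType()->getIntegerBitWidth();
              if (BW < 32 && BW != 8 && BW != 16)
                Worklist.push_back(II);
            }
        for (IntrinsicInst *II : Worklist) {
          IRBuilder<> B(II);
          unsigned BW = II->getType()->getIntegerBitWidth();
          Type *I32 = B.getInt32Ty();
          // fptosi to a legal width, then clamp to [-2^(N-1), 2^(N-1)-1].
          Value *V = B.CreateFPToSI(II->getArgOperand(0), I32);
          Value *Lo = ConstantInt::get(
              F.getContext(), APInt::getSignedMinValue(BW).sext(32));
          Value *Hi = ConstantInt::get(
              F.getContext(), APInt::getSignedMaxValue(BW).sext(32));
          V = B.CreateBinaryIntrinsic(Intrinsic::smax, V, Lo);
          V = B.CreateBinaryIntrinsic(Intrinsic::smin, V, Hi);
          // Truncate back to iN; a following sext recovers the value.
          II->replaceAllUsesWith(B.CreateTrunc(V, II->getType()));
          II->eraseFromParent();
        }
        return PreservedAnalyses::none();
      }
    };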

asudarsa commented 7 months ago

> In this case, our front-end simply doesn't have enough control over the fptosi.sat intrinsic usage. If some intrinsic calls are produced by LLVM optimizations, then it's better to have them supported by the translator; otherwise the whole pipeline becomes unreliable.
>
> As for LLVM, it looks like a flaw in the TargetTransformInfo used by default for the target. I don't see how this transformation can be beneficial for non-native types; the cost model for the saturated conversion must be incorrect.

It's a good point about TargetTransformInfo. One of the key differences between SPIR and other LLVM backends is that we do not have a 'registered' target for SPIR, so I think we might be relying on the generic TTI for cost information. Once we have a registered target, we should be able to tweak that cost model and customize it for our target.

Thanks

ienkovich commented 7 months ago

> The reasonable question: wouldn't it be much cheaper to add the appropriate workaround in the pass? If you are using intel/llvm, it should be as simple as that.

We use upstream LLVM.

etaf commented 7 months ago

Verified. @ienkovich @asudarsa @etiotto @whitneywhtsang, thanks for your help!