intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License
137 stars 41 forks source link

[DPAS] The GEMM kernel produces incorrect results with the DPAS op when threads_per_warp is 32 #416

Closed chengjunlu closed 5 months ago

chengjunlu commented 8 months ago

The GEMM kernel does not produce correct results with the DPAS op when threads_per_warp is 32.

With threads_per_warp=16 the kernel works correctly on both ATSM and PVC, but with threads_per_warp=32 it does not produce the expected results.

This needs further debugging.

tdeng5 commented 7 months ago

We use a warp size of 16 (threads_per_warp=16) by default. We will try other warp sizes, such as 8 and 32, later during performance tuning.

chengjunlu commented 6 months ago

The issue can no longer be reproduced on the latest Triton XPU llvm-target branch with the rolling IGC driver. The GEMM kernel using the DPAS instruction now produces correct results with threads_per_warp=32.

For the DPAS instruction with threads_per_warp=32, the A and B operands and the return value each use half as many scalar elements per work-item as with threads_per_warp=16. For example, each vector below has 4 elements:

`%260 = call <4 x float> @llvm.genx.GenISA.sub.group.dpas.v4f32.v4f32.v4i16.v4i32(<4 x float> %255, <4 x i16> %234, <4 x i32> %251, i32 10, i32 10, i32 8, i32 8, i1 false) #4, !dbg !418`

Kernel binary size with threads_per_warp=32:

-rw-rw-r-- 1 jovyan users 15332 Mar 27 00:46 _kernel.spv

Kernel binary size with threads_per_warp=16:

-rw-rw-r-- 1 jovyan users 22624 Mar 27 00:48 _kernel.spv

Here are the LLVM IR and GenISA assembly for the Triton GEMM kernel with threads_per_warp=32.

LLVM IR ``` ; ------------------------------------------------ ; OCL_asm4503926daa98ed73_afterUnification.ll ; ------------------------------------------------ target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024-n8:16:32" target triple = "spir64-unknown-unknown" @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM = external addrspace(3) global [0 x i8] @llvm.used = appending global [1 x i8*] [i8* addrspacecast (i8 addrspace(3)* getelementptr inbounds ([0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i32 0, i32 0) to i8*)], section "llvm.metadata" ; Function Attrs: convergent nounwind define spir_kernel void @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(half addrspace(1)* nocapture readonly %0, half addrspace(1)* nocapture readonly %1, half addrspace(1)* nocapture %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i8 addrspace(3)* nocapture %9, <8 x i32> %r0, <8 x i32> %payloadHeader, i16 %localIdX, i16 %localIdY, i16 %localIdZ, i8* %privateBase, i32 %bufferOffset, i32 %bufferOffset1, i32 %bufferOffset2) #0 !dbg !377 { call void @llvm.genx.GenISA.CatchAllDebugLine(), !dbg !381 %11 = extractelement <8 x i32> %r0, i32 0 %12 = extractelement <8 x i32> %r0, i32 1 %13 = extractelement <8 x i32> %r0, i32 2 %14 = extractelement <8 x i32> %r0, i32 3 %15 = extractelement <8 x i32> %r0, i32 4 %16 = extractelement <8 x i32> %r0, i32 5 %17 = extractelement <8 x i32> %r0, i32 6 %18 = extractelement <8 x i32> %r0, i32 7 %19 = add i32 %3, 15, !dbg !382 %20 = sdiv i32 %19, 16, !dbg !386 %21 = add i32 %4, 15, !dbg !387 %22 = sdiv i32 %21, 16, !dbg !389 %23 = shl nsw i32 %22, 3, !dbg !390 %24 = sdiv i32 %12, %23, !dbg !391 %freeze = freeze i32 %24, !dbg !392 %25 = shl i32 %freeze, 3, !dbg !392 %26 = sub i32 %20, %25, !dbg !393 %27 = icmp slt i32 %26, 8, !dbg !393 %28 = select i1 %27, i32 %26, i32 8, 
!dbg !394 %29 = srem i32 %12, %28, !dbg !395 %freeze22 = freeze i32 %29, !dbg !396 %30 = add i32 %25, %freeze22, !dbg !396 %31 = mul i32 %freeze, %23 %.decomposed = sub i32 %12, %31 %32 = sdiv i32 %.decomposed, %28, !dbg !397 %freeze23 = freeze i32 %32, !dbg !398 %33 = shl i32 %30, 4, !dbg !398 %34 = lshr i16 %localIdX, 1, !dbg !399 %35 = and i16 %34, 15, !dbg !399 %36 = zext i16 %35 to i32, !dbg !399 %localIdX8 = zext i16 %localIdX to i32 %37 = shl nuw nsw i32 %localIdX8, 3, !dbg !399 %38 = and i32 %37, 8, !dbg !399 %39 = or i32 %33, %36, !dbg !400 %40 = shl i32 %freeze23, 4, !dbg !401 %41 = or i32 %38, %40, !dbg !402 %42 = add i32 %5, 15, !dbg !403 %43 = sdiv i32 %42, 16, !dbg !405 %44 = icmp sgt i32 %42, 15, !dbg !406 br i1 %44, label %.lr.ph, label %._crit_edge, !dbg !406 .lr.ph: ; preds = %10 %45 = shl i32 %7, 4, !dbg !407 %46 = shl i32 %17, 4, !dbg !408 %47 = or i32 %46, %36, !dbg !409 %48 = mul i32 %47, %7, !dbg !410 %49 = srem i32 %41, %4, !dbg !411 %freeze24 = freeze i32 %49, !dbg !412 %50 = add i32 %48, %freeze24, !dbg !412 %51 = sext i32 %50 to i64, !dbg !413 %52 = getelementptr half, half addrspace(1)* %1, i64 %51, !dbg !413 %53 = srem i32 %39, %3, !dbg !414 %freeze25 = freeze i32 %53, !dbg !415 %54 = mul i32 %freeze25, %6, !dbg !415 %55 = add i32 %54, %46, !dbg !409 %56 = add i32 %55, %38, !dbg !416 %57 = sext i32 %56 to i64, !dbg !417 %58 = getelementptr half, half addrspace(1)* %0, i64 %57, !dbg !417 %59 = sext i32 %45 to i64 br label %60, !dbg !406 60: ; preds = %60, %.lr.ph %.pn3153 = phi half addrspace(1)* [ %52, %.lr.ph ], [ %271, %60 ] %.pn1552 = phi half addrspace(1)* [ %58, %.lr.ph ], [ %270, %60 ] %61 = phi float [ 0.000000e+00, %.lr.ph ], [ %261, %60 ], !dbg !418 %62 = phi float [ 0.000000e+00, %.lr.ph ], [ %262, %60 ], !dbg !418 %63 = phi float [ 0.000000e+00, %.lr.ph ], [ %263, %60 ], !dbg !418 %64 = phi float [ 0.000000e+00, %.lr.ph ], [ %264, %60 ], !dbg !418 %65 = phi float [ 0.000000e+00, %.lr.ph ], [ %266, %60 ], !dbg !418 %66 = phi 
float [ 0.000000e+00, %.lr.ph ], [ %267, %60 ], !dbg !418 %67 = phi float [ 0.000000e+00, %.lr.ph ], [ %268, %60 ], !dbg !418 %68 = phi float [ 0.000000e+00, %.lr.ph ], [ %269, %60 ], !dbg !418 %69 = phi i32 [ 0, %.lr.ph ], [ %272, %60 ] %70 = bitcast half addrspace(1)* %.pn1552 to i32 addrspace(1)*, !dbg !419 %71 = load i32, i32 addrspace(1)* %70, align 16, !dbg !419 %72 = getelementptr inbounds half, half addrspace(1)* %.pn1552, i64 2, !dbg !419 %73 = bitcast half addrspace(1)* %72 to i32 addrspace(1)*, !dbg !419 %74 = load i32, i32 addrspace(1)* %73, align 4, !dbg !419 %75 = getelementptr inbounds half, half addrspace(1)* %.pn1552, i64 4, !dbg !419 %76 = bitcast half addrspace(1)* %75 to i32 addrspace(1)*, !dbg !419 %77 = load i32, i32 addrspace(1)* %76, align 8, !dbg !419 %78 = getelementptr inbounds half, half addrspace(1)* %.pn1552, i64 6, !dbg !419 %79 = bitcast half addrspace(1)* %78 to i32 addrspace(1)*, !dbg !419 %80 = load i32, i32 addrspace(1)* %79, align 4, !dbg !419 %81 = trunc i32 %71 to i16, !dbg !419 %extelt.offset = lshr i32 %71, 16, !dbg !419 %82 = trunc i32 %extelt.offset to i16, !dbg !419 %83 = trunc i32 %74 to i16, !dbg !419 %extelt.offset45 = lshr i32 %74, 16, !dbg !419 %84 = trunc i32 %extelt.offset45 to i16, !dbg !419 %85 = trunc i32 %77 to i16, !dbg !419 %extelt.offset46 = lshr i32 %77, 16, !dbg !419 %86 = trunc i32 %extelt.offset46 to i16, !dbg !419 %87 = trunc i32 %80 to i16, !dbg !419 %extelt.offset47 = lshr i32 %80, 16, !dbg !419 %88 = trunc i32 %extelt.offset47 to i16, !dbg !419 call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false) call void @llvm.genx.GenISA.threadgroupbarrier() %retval.0.i7 = zext i16 %localIdX to i64 %89 = shl nuw nsw i64 %retval.0.i7, 3, !dbg !419 %90 = and i64 %89, 248, !dbg !419 %91 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !419 %92 = getelementptr half, half addrspace(3)* %91, i64 %90, 
!dbg !419 %93 = getelementptr half, half addrspace(3)* %92, i64 1, !dbg !419 %94 = getelementptr half, half addrspace(3)* %92, i64 2, !dbg !419 %95 = getelementptr half, half addrspace(3)* %92, i64 3, !dbg !419 %96 = getelementptr half, half addrspace(3)* %92, i64 4, !dbg !419 %97 = getelementptr half, half addrspace(3)* %92, i64 5, !dbg !419 %98 = getelementptr half, half addrspace(3)* %92, i64 6, !dbg !419 %99 = getelementptr half, half addrspace(3)* %92, i64 7, !dbg !419 %100 = bitcast half addrspace(3)* %92 to i16 addrspace(3)*, !dbg !419 store i16 %81, i16 addrspace(3)* %100, align 2, !dbg !419 %101 = bitcast half addrspace(3)* %93 to i16 addrspace(3)*, !dbg !419 store i16 %82, i16 addrspace(3)* %101, align 2, !dbg !419 %102 = bitcast half addrspace(3)* %94 to i16 addrspace(3)*, !dbg !419 store i16 %83, i16 addrspace(3)* %102, align 2, !dbg !419 %103 = bitcast half addrspace(3)* %95 to i16 addrspace(3)*, !dbg !419 store i16 %84, i16 addrspace(3)* %103, align 2, !dbg !419 %104 = bitcast half addrspace(3)* %96 to i16 addrspace(3)*, !dbg !419 store i16 %85, i16 addrspace(3)* %104, align 2, !dbg !419 %105 = bitcast half addrspace(3)* %97 to i16 addrspace(3)*, !dbg !419 store i16 %86, i16 addrspace(3)* %105, align 2, !dbg !419 %106 = bitcast half addrspace(3)* %98 to i16 addrspace(3)*, !dbg !419 store i16 %87, i16 addrspace(3)* %106, align 2, !dbg !419 %107 = bitcast half addrspace(3)* %99 to i16 addrspace(3)*, !dbg !419 store i16 %88, i16 addrspace(3)* %107, align 2, !dbg !419 %108 = bitcast half addrspace(1)* %.pn3153 to i32 addrspace(1)*, !dbg !420 %109 = load i32, i32 addrspace(1)* %108, align 16, !dbg !420 %110 = getelementptr inbounds half, half addrspace(1)* %.pn3153, i64 2, !dbg !420 %111 = bitcast half addrspace(1)* %110 to i32 addrspace(1)*, !dbg !420 %112 = load i32, i32 addrspace(1)* %111, align 4, !dbg !420 %113 = getelementptr inbounds half, half addrspace(1)* %.pn3153, i64 4, !dbg !420 %114 = bitcast half addrspace(1)* %113 to i32 addrspace(1)*, !dbg 
!420 %115 = load i32, i32 addrspace(1)* %114, align 8, !dbg !420 %116 = getelementptr inbounds half, half addrspace(1)* %.pn3153, i64 6, !dbg !420 %117 = bitcast half addrspace(1)* %116 to i32 addrspace(1)*, !dbg !420 %118 = load i32, i32 addrspace(1)* %117, align 4, !dbg !420 %119 = trunc i32 %109 to i16, !dbg !420 %extelt.offset48 = lshr i32 %109, 16, !dbg !420 %120 = trunc i32 %extelt.offset48 to i16, !dbg !420 %121 = trunc i32 %112 to i16, !dbg !420 %extelt.offset49 = lshr i32 %112, 16, !dbg !420 %122 = trunc i32 %extelt.offset49 to i16, !dbg !420 %123 = trunc i32 %115 to i16, !dbg !420 %extelt.offset50 = lshr i32 %115, 16, !dbg !420 %124 = trunc i32 %extelt.offset50 to i16, !dbg !420 %125 = trunc i32 %118 to i16, !dbg !420 %extelt.offset51 = lshr i32 %118, 16, !dbg !420 %126 = trunc i32 %extelt.offset51 to i16, !dbg !420 %retval.0.i11 = zext i16 %localIdX to i64 %127 = shl nuw nsw i64 %retval.0.i11, 3, !dbg !420 %128 = and i64 %127, 248, !dbg !420 %129 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %130 = bitcast i8 addrspace(3)* %129 to half addrspace(3)*, !dbg !420 %131 = getelementptr half, half addrspace(3)* %130, i64 %128, !dbg !420 %132 = getelementptr half, half addrspace(3)* %131, i64 1, !dbg !420 %133 = getelementptr half, half addrspace(3)* %131, i64 2, !dbg !420 %134 = getelementptr half, half addrspace(3)* %131, i64 3, !dbg !420 %135 = getelementptr half, half addrspace(3)* %131, i64 4, !dbg !420 %136 = getelementptr half, half addrspace(3)* %131, i64 5, !dbg !420 %137 = getelementptr half, half addrspace(3)* %131, i64 6, !dbg !420 %138 = getelementptr half, half addrspace(3)* %131, i64 7, !dbg !420 %139 = bitcast half addrspace(3)* %131 to i16 addrspace(3)*, !dbg !420 store i16 %119, i16 addrspace(3)* %139, align 2, !dbg !420 %140 = bitcast half addrspace(3)* %132 to i16 addrspace(3)*, !dbg !420 store i16 %120, i16 addrspace(3)* %140, align 2, !dbg !420 %141 = 
bitcast half addrspace(3)* %133 to i16 addrspace(3)*, !dbg !420 store i16 %121, i16 addrspace(3)* %141, align 2, !dbg !420 %142 = bitcast half addrspace(3)* %134 to i16 addrspace(3)*, !dbg !420 store i16 %122, i16 addrspace(3)* %142, align 2, !dbg !420 %143 = bitcast half addrspace(3)* %135 to i16 addrspace(3)*, !dbg !420 store i16 %123, i16 addrspace(3)* %143, align 2, !dbg !420 %144 = bitcast half addrspace(3)* %136 to i16 addrspace(3)*, !dbg !420 store i16 %124, i16 addrspace(3)* %144, align 2, !dbg !420 %145 = bitcast half addrspace(3)* %137 to i16 addrspace(3)*, !dbg !420 store i16 %125, i16 addrspace(3)* %145, align 2, !dbg !420 %146 = bitcast half addrspace(3)* %138 to i16 addrspace(3)*, !dbg !420 store i16 %126, i16 addrspace(3)* %146, align 2, !dbg !420 call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false) call void @llvm.genx.GenISA.threadgroupbarrier() %147 = and i16 %localIdX, 31, !dbg !419 %urem = zext i16 %147 to i64, !dbg !419 %148 = and i16 %localIdX, 31, !dbg !419 %149 = zext i16 %148 to i32, !dbg !419 %150 = or i32 %149, 32, !dbg !419 %151 = or i32 %149, 64, !dbg !419 %152 = or i32 %149, 96, !dbg !419 %153 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !419 %154 = getelementptr half, half addrspace(3)* %153, i64 %urem, !dbg !419 %155 = zext i32 %150 to i64, !dbg !419 %156 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !419 %157 = getelementptr half, half addrspace(3)* %156, i64 %155, !dbg !419 %158 = zext i32 %151 to i64, !dbg !419 %159 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !419 %160 = getelementptr half, half addrspace(3)* %159, i64 %158, !dbg !419 %161 = zext i32 %152 to i64, !dbg !419 %162 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !419 %163 = 
getelementptr half, half addrspace(3)* %162, i64 %161, !dbg !419 %164 = load half, half addrspace(3)* %154, align 2, !dbg !419 %165 = load half, half addrspace(3)* %157, align 2, !dbg !419 %166 = load half, half addrspace(3)* %160, align 2, !dbg !419 %167 = load half, half addrspace(3)* %163, align 2, !dbg !419 %168 = getelementptr half, half addrspace(3)* %154, i64 128, !dbg !419 %169 = load half, half addrspace(3)* %168, align 2, !dbg !419 %170 = getelementptr half, half addrspace(3)* %157, i64 128, !dbg !419 %171 = load half, half addrspace(3)* %170, align 2, !dbg !419 %172 = getelementptr half, half addrspace(3)* %160, i64 128, !dbg !419 %173 = load half, half addrspace(3)* %172, align 2, !dbg !419 %174 = getelementptr half, half addrspace(3)* %163, i64 128, !dbg !419 %175 = load half, half addrspace(3)* %174, align 2, !dbg !419 %localIdX12 = zext i16 %localIdX to i32 %176 = and i32 %localIdX12, 15, !dbg !420 %177 = shl nuw nsw i32 %localIdX12, 1, !dbg !420 %178 = and i32 %177, 32, !dbg !420 %179 = or i32 %178, %176, !dbg !420 %180 = or i32 %179, 16, !dbg !420 %181 = or i32 %179, 64, !dbg !420 %182 = or i32 %179, 80, !dbg !420 %183 = or i32 %179, 128, !dbg !420 %184 = or i32 %179, 144, !dbg !420 %185 = or i32 %179, 192, !dbg !420 %186 = or i32 %179, 208, !dbg !420 %187 = zext i32 %179 to i64, !dbg !420 %188 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %189 = bitcast i8 addrspace(3)* %188 to half addrspace(3)*, !dbg !420 %190 = getelementptr half, half addrspace(3)* %189, i64 %187, !dbg !420 %191 = zext i32 %180 to i64, !dbg !420 %192 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %193 = bitcast i8 addrspace(3)* %192 to half addrspace(3)*, !dbg !420 %194 = getelementptr half, half addrspace(3)* %193, i64 %191, !dbg !420 %195 = zext i32 %181 to i64, !dbg !420 %196 = getelementptr inbounds [0 x i8], 
[0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %197 = bitcast i8 addrspace(3)* %196 to half addrspace(3)*, !dbg !420 %198 = getelementptr half, half addrspace(3)* %197, i64 %195, !dbg !420 %199 = zext i32 %182 to i64, !dbg !420 %200 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %201 = bitcast i8 addrspace(3)* %200 to half addrspace(3)*, !dbg !420 %202 = getelementptr half, half addrspace(3)* %201, i64 %199, !dbg !420 %203 = zext i32 %183 to i64, !dbg !420 %204 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %205 = bitcast i8 addrspace(3)* %204 to half addrspace(3)*, !dbg !420 %206 = getelementptr half, half addrspace(3)* %205, i64 %203, !dbg !420 %207 = zext i32 %184 to i64, !dbg !420 %208 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %209 = bitcast i8 addrspace(3)* %208 to half addrspace(3)*, !dbg !420 %210 = getelementptr half, half addrspace(3)* %209, i64 %207, !dbg !420 %211 = zext i32 %185 to i64, !dbg !420 %212 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %213 = bitcast i8 addrspace(3)* %212 to half addrspace(3)*, !dbg !420 %214 = getelementptr half, half addrspace(3)* %213, i64 %211, !dbg !420 %215 = zext i32 %186 to i64, !dbg !420 %216 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM, i64 0, i64 1024, !dbg !420 %217 = bitcast i8 addrspace(3)* %216 to half addrspace(3)*, !dbg !420 %218 = getelementptr half, half addrspace(3)* %217, i64 %215, !dbg !420 %219 = load half, half addrspace(3)* %190, align 2, !dbg !420 %220 = load half, half addrspace(3)* %194, align 2, !dbg !420 %221 = load half, half addrspace(3)* %198, align 2, !dbg !420 
%222 = load half, half addrspace(3)* %202, align 2, !dbg !420 %223 = load half, half addrspace(3)* %206, align 2, !dbg !420 %224 = load half, half addrspace(3)* %210, align 2, !dbg !420 %225 = load half, half addrspace(3)* %214, align 2, !dbg !420 %226 = load half, half addrspace(3)* %218, align 2, !dbg !420 %227 = bitcast half %164 to i16, !dbg !418 %228 = bitcast half %165 to i16, !dbg !418 %229 = bitcast half %166 to i16, !dbg !418 %230 = bitcast half %167 to i16, !dbg !418 %231 = insertelement <4 x i16> undef, i16 %227, i32 0, !dbg !418 %232 = insertelement <4 x i16> %231, i16 %228, i32 1, !dbg !418 %233 = insertelement <4 x i16> %232, i16 %229, i32 2, !dbg !418 %234 = insertelement <4 x i16> %233, i16 %230, i32 3, !dbg !418 %235 = bitcast half %169 to i16, !dbg !418 %236 = bitcast half %171 to i16, !dbg !418 %237 = bitcast half %173 to i16, !dbg !418 %238 = bitcast half %175 to i16, !dbg !418 %239 = insertelement <4 x i16> undef, i16 %235, i32 0, !dbg !418 %240 = insertelement <4 x i16> %239, i16 %236, i32 1, !dbg !418 %241 = insertelement <4 x i16> %240, i16 %237, i32 2, !dbg !418 %242 = insertelement <4 x i16> %241, i16 %238, i32 3, !dbg !418 %243 = insertelement <8 x half> undef, half %219, i32 0, !dbg !418 %244 = insertelement <8 x half> %243, half %220, i32 1, !dbg !418 %245 = insertelement <8 x half> %244, half %221, i32 2, !dbg !418 %246 = insertelement <8 x half> %245, half %222, i32 3, !dbg !418 %247 = insertelement <8 x half> %246, half %223, i32 4, !dbg !418 %248 = insertelement <8 x half> %247, half %224, i32 5, !dbg !418 %249 = insertelement <8 x half> %248, half %225, i32 6, !dbg !418 %250 = insertelement <8 x half> %249, half %226, i32 7, !dbg !418 %251 = bitcast <8 x half> %250 to <4 x i32>, !dbg !418 %252 = insertelement <4 x float> undef, float %61, i32 0, !dbg !418 %253 = insertelement <4 x float> %252, float %62, i32 1, !dbg !418 %254 = insertelement <4 x float> %253, float %63, i32 2, !dbg !418 %255 = insertelement <4 x float> %254, float 
%64, i32 3, !dbg !418 %256 = insertelement <4 x float> undef, float %65, i32 0, !dbg !418 %257 = insertelement <4 x float> %256, float %66, i32 1, !dbg !418 %258 = insertelement <4 x float> %257, float %67, i32 2, !dbg !418 %259 = insertelement <4 x float> %258, float %68, i32 3, !dbg !418 %260 = call <4 x float> @llvm.genx.GenISA.sub.group.dpas.v4f32.v4f32.v4i16.v4i32(<4 x float> %255, <4 x i16> %234, <4 x i32> %251, i32 10, i32 10, i32 8, i32 8, i1 false) #4, !dbg !418 %261 = extractelement <4 x float> %260, i32 0, !dbg !418 %262 = extractelement <4 x float> %260, i32 1, !dbg !418 %263 = extractelement <4 x float> %260, i32 2, !dbg !418 %264 = extractelement <4 x float> %260, i32 3, !dbg !418 %265 = call <4 x float> @llvm.genx.GenISA.sub.group.dpas.v4f32.v4f32.v4i16.v4i32(<4 x float> %259, <4 x i16> %242, <4 x i32> %251, i32 10, i32 10, i32 8, i32 8, i1 false) #4, !dbg !418 %266 = extractelement <4 x float> %265, i32 0, !dbg !418 %267 = extractelement <4 x float> %265, i32 1, !dbg !418 %268 = extractelement <4 x float> %265, i32 2, !dbg !418 %269 = extractelement <4 x float> %265, i32 3, !dbg !418 %270 = getelementptr half, half addrspace(1)* %.pn1552, i64 16, !dbg !421 %271 = getelementptr half, half addrspace(1)* %.pn3153, i64 %59, !dbg !422 %272 = add nuw nsw i32 %69, 1, !dbg !406 %exitcond.not = icmp eq i32 %272, %43, !dbg !406 br i1 %exitcond.not, label %._crit_edge.loopexit, label %60, !dbg !406 ._crit_edge.loopexit: ; preds = %60 %273 = fptrunc float %261 to half, !dbg !423 %274 = fptrunc float %262 to half, !dbg !423 %275 = fptrunc float %263 to half, !dbg !423 %276 = fptrunc float %264 to half, !dbg !423 %277 = fptrunc float %266 to half, !dbg !423 %278 = fptrunc float %267 to half, !dbg !423 %279 = fptrunc float %268 to half, !dbg !423 %280 = fptrunc float %269 to half, !dbg !423 br label %._crit_edge, !dbg !423 ._crit_edge: ; preds = %._crit_edge.loopexit, %10 %281 = phi half [ 0xH0000, %10 ], [ %273, %._crit_edge.loopexit ], !dbg !424 %282 = phi half 
[ 0xH0000, %10 ], [ %274, %._crit_edge.loopexit ], !dbg !424 %283 = phi half [ 0xH0000, %10 ], [ %275, %._crit_edge.loopexit ], !dbg !424 %284 = phi half [ 0xH0000, %10 ], [ %276, %._crit_edge.loopexit ], !dbg !424 %285 = phi half [ 0xH0000, %10 ], [ %277, %._crit_edge.loopexit ], !dbg !424 %286 = phi half [ 0xH0000, %10 ], [ %278, %._crit_edge.loopexit ], !dbg !424 %287 = phi half [ 0xH0000, %10 ], [ %279, %._crit_edge.loopexit ], !dbg !424 %288 = phi half [ 0xH0000, %10 ], [ %280, %._crit_edge.loopexit ], !dbg !424 %289 = icmp slt i32 %39, %3, !dbg !425 %290 = icmp slt i32 %41, %4, !dbg !426 %291 = and i1 %289, %290, !dbg !427 call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false) call void @llvm.genx.GenISA.threadgroupbarrier() %retval.0.i17 = zext i16 %localIdX to i64 %292 = and i64 %retval.0.i17, 15, !dbg !424 %293 = and i64 %retval.0.i17, 16, !dbg !424 %.not = icmp eq i64 %293, 0, !dbg !424 %294 = select i1 %.not, i64 0, i64 24, !dbg !424 %295 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %296 = getelementptr half, half addrspace(3)* %295, i64 %294, !dbg !424 %297 = getelementptr half, half addrspace(3)* %296, i64 %292, !dbg !424 %298 = insertelement <1 x half> undef, half %281, i32 0, !dbg !424 %299 = bitcast half addrspace(3)* %297 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %298, <1 x half> addrspace(3)* %299, align 2, !dbg !424 %retval.0.i19 = zext i16 %localIdX to i64 %300 = and i64 %retval.0.i19, 15, !dbg !424 %301 = and i64 %retval.0.i19, 16, !dbg !424 %.not32 = icmp eq i64 %301, 0, !dbg !424 %302 = select i1 %.not32, i64 48, i64 72, !dbg !424 %303 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %304 = getelementptr half, half addrspace(3)* %303, i64 %302, !dbg !424 %305 = getelementptr half, half addrspace(3)* %304, i64 %300, !dbg !424 %306 = 
insertelement <1 x half> undef, half %282, i32 0, !dbg !424 %307 = bitcast half addrspace(3)* %305 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %306, <1 x half> addrspace(3)* %307, align 2, !dbg !424 %retval.0.i21 = zext i16 %localIdX to i64 %308 = and i64 %retval.0.i21, 15, !dbg !424 %309 = and i64 %retval.0.i21, 16, !dbg !424 %.not33 = icmp eq i64 %309, 0, !dbg !424 %310 = select i1 %.not33, i64 96, i64 120, !dbg !424 %311 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %312 = getelementptr half, half addrspace(3)* %311, i64 %310, !dbg !424 %313 = getelementptr half, half addrspace(3)* %312, i64 %308, !dbg !424 %314 = insertelement <1 x half> undef, half %283, i32 0, !dbg !424 %315 = bitcast half addrspace(3)* %313 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %314, <1 x half> addrspace(3)* %315, align 2, !dbg !424 %retval.0.i23 = zext i16 %localIdX to i64 %316 = and i64 %retval.0.i23, 15, !dbg !424 %317 = and i64 %retval.0.i23, 16, !dbg !424 %.not34 = icmp eq i64 %317, 0, !dbg !424 %318 = select i1 %.not34, i64 144, i64 168, !dbg !424 %319 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %320 = getelementptr half, half addrspace(3)* %319, i64 %318, !dbg !424 %321 = getelementptr half, half addrspace(3)* %320, i64 %316, !dbg !424 %322 = insertelement <1 x half> undef, half %284, i32 0, !dbg !424 %323 = bitcast half addrspace(3)* %321 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %322, <1 x half> addrspace(3)* %323, align 2, !dbg !424 %retval.0.i25 = zext i16 %localIdX to i64 %324 = and i64 %retval.0.i25, 15, !dbg !424 %325 = and i64 %retval.0.i25, 16, !dbg !424 %.not35 = icmp eq i64 %325, 0, !dbg !424 %326 = select i1 %.not35, i64 192, i64 216, !dbg !424 %327 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %328 = getelementptr half, half addrspace(3)* 
%327, i64 %326, !dbg !424 %329 = getelementptr half, half addrspace(3)* %328, i64 %324, !dbg !424 %330 = insertelement <1 x half> undef, half %285, i32 0, !dbg !424 %331 = bitcast half addrspace(3)* %329 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %330, <1 x half> addrspace(3)* %331, align 2, !dbg !424 %retval.0.i27 = zext i16 %localIdX to i64 %332 = and i64 %retval.0.i27, 15, !dbg !424 %333 = and i64 %retval.0.i27, 16, !dbg !424 %.not36 = icmp eq i64 %333, 0, !dbg !424 %334 = select i1 %.not36, i64 240, i64 264, !dbg !424 %335 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %336 = getelementptr half, half addrspace(3)* %335, i64 %334, !dbg !424 %337 = getelementptr half, half addrspace(3)* %336, i64 %332, !dbg !424 %338 = insertelement <1 x half> undef, half %286, i32 0, !dbg !424 %339 = bitcast half addrspace(3)* %337 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %338, <1 x half> addrspace(3)* %339, align 2, !dbg !424 %retval.0.i29 = zext i16 %localIdX to i64 %340 = and i64 %retval.0.i29, 15, !dbg !424 %341 = and i64 %retval.0.i29, 16, !dbg !424 %.not37 = icmp eq i64 %341, 0, !dbg !424 %342 = select i1 %.not37, i64 288, i64 312, !dbg !424 %343 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %344 = getelementptr half, half addrspace(3)* %343, i64 %342, !dbg !424 %345 = getelementptr half, half addrspace(3)* %344, i64 %340, !dbg !424 %346 = insertelement <1 x half> undef, half %287, i32 0, !dbg !424 %347 = bitcast half addrspace(3)* %345 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %346, <1 x half> addrspace(3)* %347, align 2, !dbg !424 %retval.0.i31 = zext i16 %localIdX to i64 %348 = and i64 %retval.0.i31, 15, !dbg !424 %349 = and i64 %retval.0.i31, 16, !dbg !424 %.not38 = icmp eq i64 %349, 0, !dbg !424 %350 = select i1 %.not38, i64 336, i64 360, !dbg !424 %351 = bitcast [0 x i8] addrspace(3)* 
@_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %352 = getelementptr half, half addrspace(3)* %351, i64 %350, !dbg !424 %353 = getelementptr half, half addrspace(3)* %352, i64 %348, !dbg !424 %354 = insertelement <1 x half> undef, half %288, i32 0, !dbg !424 %355 = bitcast half addrspace(3)* %353 to <1 x half> addrspace(3)*, !dbg !424 store <1 x half> %354, <1 x half> addrspace(3)* %355, align 2, !dbg !424 call void @llvm.genx.GenISA.memoryfence(i1 true, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false) call void @llvm.genx.GenISA.threadgroupbarrier() %retval.0.i9 = zext i16 %localIdX to i64 %356 = lshr i64 %retval.0.i9, 1, !dbg !424 %357 = and i64 %356, 15, !dbg !424 %358 = shl nuw nsw i64 %retval.0.i9, 3, !dbg !424 %359 = and i64 %358, 8, !dbg !424 %360 = mul nuw nsw i64 %357, 24, !dbg !424 %361 = bitcast [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM to half addrspace(3)*, !dbg !424 %362 = getelementptr half, half addrspace(3)* %361, i64 %360, !dbg !424 %363 = getelementptr half, half addrspace(3)* %362, i64 %359, !dbg !424 %364 = bitcast half addrspace(3)* %363 to <4 x i32> addrspace(3)*, !dbg !424 %365 = load <4 x i32>, <4 x i32> addrspace(3)* %364, align 16, !dbg !424 br i1 %291, label %366, label %372, !dbg !424 366: ; preds = %._crit_edge %367 = mul i32 %39, %8, !dbg !428 %368 = add i32 %367, %41, !dbg !429 %369 = sext i32 %368 to i64, !dbg !430 %370 = getelementptr half, half addrspace(1)* %2, i64 %369, !dbg !430 %371 = bitcast half addrspace(1)* %370 to <4 x i32> addrspace(1)*, !dbg !424 store <4 x i32> %365, <4 x i32> addrspace(1)* %371, align 16, !dbg !424 br label %372, !dbg !424 372: ; preds = %366, %._crit_edge ret void, !dbg !431 } ; Function Desc: ; Output: ; Arg 0: ; Arg 1: ; Arg 2: ; Arg 3: ; Arg 4: ; Arg 5: ; Arg 6: ; Arg 7: declare <4 x float> @llvm.genx.GenISA.sub.group.dpas.v4f32.v4f32.v4i16.v4i32(<4 x float>, <4 x i16>, <4 x i32>, i32, i32, i32, i32, i1) ; Function Attrs: 
convergent declare spir_func void @__builtin_IB_thread_group_barrier() local_unnamed_addr #1 ; Function Attrs: convergent declare spir_func void @__builtin_IB_memfence(i1 noundef zeroext, i1 noundef zeroext, i1 noundef zeroext, i1 noundef zeroext, i1 noundef zeroext, i1 noundef zeroext, i1 noundef zeroext, i1 noundef zeroext) local_unnamed_addr #1 ; Function Attrs: convergent mustprogress nofree nounwind readnone willreturn declare spir_func i32 @__builtin_IB_get_group_id(i32 noundef) local_unnamed_addr #2 ; Function Attrs: convergent mustprogress nofree nounwind readnone willreturn declare spir_func i32 @__builtin_IB_get_local_id_x() local_unnamed_addr #2 ; Function Desc: ; Output: ; Arg 0: ; Arg 1: ; Arg 2: ; Arg 3: ; Arg 4: ; Arg 5: ; Arg 6: ; Arg 7: ; Function Attrs: convergent nounwind declare void @llvm.genx.GenISA.memoryfence(i1, i1, i1, i1, i1, i1, i1, i1) #0 ; Function Desc: ; Output: ; Function Attrs: convergent nounwind declare void @llvm.genx.GenISA.threadgroupbarrier() #0 ; Function Desc: ; Output: ; Function Attrs: nounwind readnone declare void @llvm.genx.GenISA.CatchAllDebugLine() #3 attributes #0 = { convergent nounwind } attributes #1 = { convergent "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } attributes #2 = { convergent mustprogress nofree nounwind readnone willreturn "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } attributes #3 = { nounwind readnone } attributes #4 = { nounwind } !llvm.module.flags = !{!0, !1, !2} !llvm.dbg.cu = !{!3} !spirv.MemoryModel = !{!5} !spirv.Source = !{!6} !spirv.Generator = !{!7} !igc.functions = !{!8} !IGCMetadata = !{!25} !opencl.ocl.version = !{!375, !375, !375, !375, !375} !opencl.spir.version = !{!375, !375, !375, !375, !375} !llvm.ident = !{!376, !376, !376, !376, !376} !0 = !{i32 7, !"Dwarf Version", i32 0} !1 = !{i32 2, !"Debug Info Version", i32 3} !2 = !{i32 1, !"wchar_size", i32 4} !3 = distinct !DICompileUnit(language: 
DW_LANG_OpenCL, file: !4, producer: "triton", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug) !4 = !DIFile(filename: "matmul.py", directory: "/home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops") !5 = !{i32 2, i32 2} !6 = !{i32 4, i32 100000} !7 = !{i16 6, i16 14} !8 = !{void (half addrspace(1)*, half addrspace(1)*, half addrspace(1)*, i32, i32, i32, i32, i32, i32, i8 addrspace(3)*, <8 x i32>, <8 x i32>, i16, i16, i16, i8*, i32, i32, i32)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c, !9} !9 = !{!10, !11, !24} !10 = !{!"function_type", i32 0} !11 = !{!"implicit_arg_desc", !12, !13, !14, !15, !16, !17, !18, !20, !22} !12 = !{i32 0} !13 = !{i32 1} !14 = !{i32 7} !15 = !{i32 8} !16 = !{i32 9} !17 = !{i32 12} !18 = !{i32 14, !19} !19 = !{!"explicit_arg_num", i32 0} !20 = !{i32 14, !21} !21 = !{!"explicit_arg_num", i32 1} !22 = !{i32 14, !23} !23 = !{!"explicit_arg_num", i32 2} !24 = !{!"sub_group_size", i32 32} !25 = !{!"ModuleMD", !26, !27, !97, !242, !272, !288, !305, !315, !317, !318, !331, !332, !333, !334, !338, !339, !340, !341, !342, !343, !344, !345, !346, !347, !348, !349, !350, !351, !353, !357, !358, !359, !360, !361, !362, !363, !364, !365, !167, !366, !367, !368, !370, !373, !374} !26 = !{!"isPrecise", i1 false} !27 = !{!"compOpt", !28, !29, !30, !31, !32, !33, !34, !35, !36, !37, !38, !39, !40, !41, !42, !43, !44, !45, !46, !47, !48, !49, !50, !51, !52, !53, !54, !55, !56, !57, !58, !59, !60, !61, !62, !63, !64, !65, !66, !67, !68, !69, !70, !71, !72, !73, !74, !75, !76, !77, !78, !79, !80, !81, !82, !83, !84, !85, !86, !87, !88, !89, !90, !91, !92, !93, !94, !95, !96} !28 = !{!"DenormsAreZero", i1 false} !29 = !{!"CorrectlyRoundedDivSqrt", i1 false} !30 = !{!"OptDisable", i1 false} !31 = !{!"MadEnable", i1 false} !32 = !{!"NoSignedZeros", i1 false} !33 = !{!"NoNaNs", i1 false} !34 = !{!"FloatRoundingMode", i32 0} !35 = !{!"FloatCvtIntRoundingMode", i32 3} !36 = !{!"LoadCacheDefault", i32 4} !37 = !{!"StoreCacheDefault", 
i32 2} !38 = !{!"VISAPreSchedRPThreshold", i32 0} !39 = !{!"SetLoopUnrollThreshold", i32 0} !40 = !{!"UnsafeMathOptimizations", i1 false} !41 = !{!"disableCustomUnsafeOpts", i1 false} !42 = !{!"disableReducePow", i1 false} !43 = !{!"FiniteMathOnly", i1 false} !44 = !{!"FastRelaxedMath", i1 false} !45 = !{!"DashGSpecified", i1 false} !46 = !{!"FastCompilation", i1 false} !47 = !{!"UseScratchSpacePrivateMemory", i1 true} !48 = !{!"RelaxedBuiltins", i1 false} !49 = !{!"SubgroupIndependentForwardProgressRequired", i1 true} !50 = !{!"GreaterThan2GBBufferRequired", i1 true} !51 = !{!"GreaterThan4GBBufferRequired", i1 true} !52 = !{!"DisableA64WA", i1 false} !53 = !{!"ForceEnableA64WA", i1 false} !54 = !{!"PushConstantsEnable", i1 true} !55 = !{!"HasPositivePointerOffset", i1 false} !56 = !{!"HasBufferOffsetArg", i1 true} !57 = !{!"BufferOffsetArgOptional", i1 true} !58 = !{!"replaceGlobalOffsetsByZero", i1 false} !59 = !{!"forcePixelShaderSIMDMode", i32 0} !60 = !{!"pixelShaderDoNotAbortOnSpill", i1 false} !61 = !{!"UniformWGS", i1 false} !62 = !{!"disableVertexComponentPacking", i1 false} !63 = !{!"disablePartialVertexComponentPacking", i1 false} !64 = !{!"PreferBindlessImages", i1 false} !65 = !{!"UseBindlessMode", i1 false} !66 = !{!"UseLegacyBindlessMode", i1 true} !67 = !{!"disableMathRefactoring", i1 false} !68 = !{!"atomicBranch", i1 false} !69 = !{!"ForceInt32DivRemEmu", i1 false} !70 = !{!"ForceInt32DivRemEmuSP", i1 false} !71 = !{!"DisableFastestSingleCSSIMD", i1 false} !72 = !{!"DisableFastestLinearScan", i1 false} !73 = !{!"UseStatelessforPrivateMemory", i1 false} !74 = !{!"EnableTakeGlobalAddress", i1 false} !75 = !{!"IsLibraryCompilation", i1 false} !76 = !{!"FastVISACompile", i1 false} !77 = !{!"MatchSinCosPi", i1 false} !78 = !{!"ExcludeIRFromZEBinary", i1 false} !79 = !{!"EmitZeBinVISASections", i1 false} !80 = !{!"FP64GenEmulationEnabled", i1 false} !81 = !{!"allowDisableRematforCS", i1 false} !82 = !{!"DisableIncSpillCostAllAddrTaken", i1 false} !83 = 
!{!"DisableCPSOmaskWA", i1 false} !84 = !{!"DisableFastestGopt", i1 false} !85 = !{!"WaForceHalfPromotionComputeShader", i1 false} !86 = !{!"WaForceHalfPromotionPixelVertexShader", i1 false} !87 = !{!"DisableConstantCoalescing", i1 false} !88 = !{!"EnableUndefAlphaOutputAsRed", i1 true} !89 = !{!"WaEnableALTModeVisaWA", i1 false} !90 = !{!"NewSpillCostFunction", i1 false} !91 = !{!"ForceLargeGRFNum4RQ", i1 false} !92 = !{!"DisableEUFusion", i1 false} !93 = !{!"DisableFDivToFMulInvOpt", i1 false} !94 = !{!"initializePhiSampleSourceWA", i1 false} !95 = !{!"WaDisableSubspanUseNoMaskForCB", i1 false} !96 = !{!"FastestS1Options", i32 0} !97 = !{!"FuncMD", !98, !99} !98 = !{!"FuncMDMap[0]", void (half addrspace(1)*, half addrspace(1)*, half addrspace(1)*, i32, i32, i32, i32, i32, i32, i8 addrspace(3)*, <8 x i32>, <8 x i32>, i16, i16, i16, i8*, i32, i32, i32)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c} !99 = !{!"FuncMDValue[0]", !100, !104, !108, !109, !110, !131, !159, !160, !161, !162, !163, !164, !165, !166, !167, !168, !169, !170, !171, !172, !173, !174, !185, !196, !207, !218, !229, !240, !241} !100 = !{!"localOffsets", !101} !101 = !{!"localOffsetsVec[0]", !102, !103} !102 = !{!"m_Offset", i32 0} !103 = !{!"m_Var", [0 x i8] addrspace(3)* @_kernel_0d1d2d3d4d5d6d7c8d9c10d11c-ExtSLM} !104 = !{!"workGroupWalkOrder", !105, !106, !107} !105 = !{!"dim0", i32 0} !106 = !{!"dim1", i32 0} !107 = !{!"dim2", i32 0} !108 = !{!"funcArgs"} !109 = !{!"functionType", !"KernelFunction"} !110 = !{!"rtInfo", !111, !112, !113, !114, !115, !116, !117, !118, !119, !120, !121, !122, !123, !124, !125, !126, !130} !111 = !{!"callableShaderType", !"NumberOfCallableShaderTypes"} !112 = !{!"isContinuation", i1 false} !113 = !{!"hasTraceRayPayload", i1 false} !114 = !{!"hasHitAttributes", i1 false} !115 = !{!"hasCallableData", i1 false} !116 = !{!"ShaderStackSize", i32 0} !117 = !{!"ShaderHash", i64 0} !118 = !{!"ShaderName", !""} !119 = !{!"ParentName", !""} !120 = !{!"SlotNum", i1* null} !121 = 
!{!"NOSSize", i32 0} !122 = !{!"globalRootSignatureSize", i32 0} !123 = !{!"Entries"} !124 = !{!"SpillUnions"} !125 = !{!"CustomHitAttrSizeInBytes", i32 0} !126 = !{!"Types", !127, !128, !129} !127 = !{!"FrameStartTys"} !128 = !{!"ArgumentTys"} !129 = !{!"FullFrameTys"} !130 = !{!"Aliases"} !131 = !{!"resAllocMD", !132, !133, !134, !135, !158} !132 = !{!"uavsNumType", i32 0} !133 = !{!"srvsNumType", i32 0} !134 = !{!"samplersNumType", i32 0} !135 = !{!"argAllocMDList", !136, !140, !141, !142, !143, !144, !145, !146, !147, !148, !149, !150, !151, !152, !153, !154, !155, !156, !157} !136 = !{!"argAllocMDListVec[0]", !137, !138, !139} !137 = !{!"type", i32 0} !138 = !{!"extensionType", i32 -1} !139 = !{!"indexType", i32 -1} !140 = !{!"argAllocMDListVec[1]", !137, !138, !139} !141 = !{!"argAllocMDListVec[2]", !137, !138, !139} !142 = !{!"argAllocMDListVec[3]", !137, !138, !139} !143 = !{!"argAllocMDListVec[4]", !137, !138, !139} !144 = !{!"argAllocMDListVec[5]", !137, !138, !139} !145 = !{!"argAllocMDListVec[6]", !137, !138, !139} !146 = !{!"argAllocMDListVec[7]", !137, !138, !139} !147 = !{!"argAllocMDListVec[8]", !137, !138, !139} !148 = !{!"argAllocMDListVec[9]", !137, !138, !139} !149 = !{!"argAllocMDListVec[10]", !137, !138, !139} !150 = !{!"argAllocMDListVec[11]", !137, !138, !139} !151 = !{!"argAllocMDListVec[12]", !137, !138, !139} !152 = !{!"argAllocMDListVec[13]", !137, !138, !139} !153 = !{!"argAllocMDListVec[14]", !137, !138, !139} !154 = !{!"argAllocMDListVec[15]", !137, !138, !139} !155 = !{!"argAllocMDListVec[16]", !137, !138, !139} !156 = !{!"argAllocMDListVec[17]", !137, !138, !139} !157 = !{!"argAllocMDListVec[18]", !137, !138, !139} !158 = !{!"inlineSamplersMD"} !159 = !{!"maxByteOffsets"} !160 = !{!"IsInitializer", i1 false} !161 = !{!"IsFinalizer", i1 false} !162 = !{!"CompiledSubGroupsNumber", i32 0} !163 = !{!"hasInlineVmeSamplers", i1 false} !164 = !{!"localSize", i32 0} !165 = !{!"localIDPresent", i1 true} !166 = !{!"groupIDPresent", i1 true} 
!167 = !{!"privateMemoryPerWI", i32 0} !168 = !{!"globalIDPresent", i1 false} !169 = !{!"hasSyncRTCalls", i1 false} !170 = !{!"hasNonKernelArgLoad", i1 false} !171 = !{!"hasNonKernelArgStore", i1 false} !172 = !{!"hasNonKernelArgAtomic", i1 false} !173 = !{!"UserAnnotations"} !174 = !{!"m_OpenCLArgAddressSpaces", !175, !176, !177, !178, !179, !180, !181, !182, !183, !184} !175 = !{!"m_OpenCLArgAddressSpacesVec[0]", i32 1} !176 = !{!"m_OpenCLArgAddressSpacesVec[1]", i32 1} !177 = !{!"m_OpenCLArgAddressSpacesVec[2]", i32 1} !178 = !{!"m_OpenCLArgAddressSpacesVec[3]", i32 0} !179 = !{!"m_OpenCLArgAddressSpacesVec[4]", i32 0} !180 = !{!"m_OpenCLArgAddressSpacesVec[5]", i32 0} !181 = !{!"m_OpenCLArgAddressSpacesVec[6]", i32 0} !182 = !{!"m_OpenCLArgAddressSpacesVec[7]", i32 0} !183 = !{!"m_OpenCLArgAddressSpacesVec[8]", i32 0} !184 = !{!"m_OpenCLArgAddressSpacesVec[9]", i32 3} !185 = !{!"m_OpenCLArgAccessQualifiers", !186, !187, !188, !189, !190, !191, !192, !193, !194, !195} !186 = !{!"m_OpenCLArgAccessQualifiersVec[0]", !"none"} !187 = !{!"m_OpenCLArgAccessQualifiersVec[1]", !"none"} !188 = !{!"m_OpenCLArgAccessQualifiersVec[2]", !"none"} !189 = !{!"m_OpenCLArgAccessQualifiersVec[3]", !"none"} !190 = !{!"m_OpenCLArgAccessQualifiersVec[4]", !"none"} !191 = !{!"m_OpenCLArgAccessQualifiersVec[5]", !"none"} !192 = !{!"m_OpenCLArgAccessQualifiersVec[6]", !"none"} !193 = !{!"m_OpenCLArgAccessQualifiersVec[7]", !"none"} !194 = !{!"m_OpenCLArgAccessQualifiersVec[8]", !"none"} !195 = !{!"m_OpenCLArgAccessQualifiersVec[9]", !"none"} !196 = !{!"m_OpenCLArgTypes", !197, !198, !199, !200, !201, !202, !203, !204, !205, !206} !197 = !{!"m_OpenCLArgTypesVec[0]", !"half*"} !198 = !{!"m_OpenCLArgTypesVec[1]", !"half*"} !199 = !{!"m_OpenCLArgTypesVec[2]", !"half*"} !200 = !{!"m_OpenCLArgTypesVec[3]", !"int"} !201 = !{!"m_OpenCLArgTypesVec[4]", !"int"} !202 = !{!"m_OpenCLArgTypesVec[5]", !"int"} !203 = !{!"m_OpenCLArgTypesVec[6]", !"int"} !204 = !{!"m_OpenCLArgTypesVec[7]", !"int"} !205 
= !{!"m_OpenCLArgTypesVec[8]", !"int"} !206 = !{!"m_OpenCLArgTypesVec[9]", !"char*"} !207 = !{!"m_OpenCLArgBaseTypes", !208, !209, !210, !211, !212, !213, !214, !215, !216, !217} !208 = !{!"m_OpenCLArgBaseTypesVec[0]", !"half*"} !209 = !{!"m_OpenCLArgBaseTypesVec[1]", !"half*"} !210 = !{!"m_OpenCLArgBaseTypesVec[2]", !"half*"} !211 = !{!"m_OpenCLArgBaseTypesVec[3]", !"int"} !212 = !{!"m_OpenCLArgBaseTypesVec[4]", !"int"} !213 = !{!"m_OpenCLArgBaseTypesVec[5]", !"int"} !214 = !{!"m_OpenCLArgBaseTypesVec[6]", !"int"} !215 = !{!"m_OpenCLArgBaseTypesVec[7]", !"int"} !216 = !{!"m_OpenCLArgBaseTypesVec[8]", !"int"} !217 = !{!"m_OpenCLArgBaseTypesVec[9]", !"char*"} !218 = !{!"m_OpenCLArgTypeQualifiers", !219, !220, !221, !222, !223, !224, !225, !226, !227, !228} !219 = !{!"m_OpenCLArgTypeQualifiersVec[0]", !""} !220 = !{!"m_OpenCLArgTypeQualifiersVec[1]", !""} !221 = !{!"m_OpenCLArgTypeQualifiersVec[2]", !""} !222 = !{!"m_OpenCLArgTypeQualifiersVec[3]", !""} !223 = !{!"m_OpenCLArgTypeQualifiersVec[4]", !""} !224 = !{!"m_OpenCLArgTypeQualifiersVec[5]", !""} !225 = !{!"m_OpenCLArgTypeQualifiersVec[6]", !""} !226 = !{!"m_OpenCLArgTypeQualifiersVec[7]", !""} !227 = !{!"m_OpenCLArgTypeQualifiersVec[8]", !""} !228 = !{!"m_OpenCLArgTypeQualifiersVec[9]", !""} !229 = !{!"m_OpenCLArgNames", !230, !231, !232, !233, !234, !235, !236, !237, !238, !239} !230 = !{!"m_OpenCLArgNamesVec[0]", !""} !231 = !{!"m_OpenCLArgNamesVec[1]", !""} !232 = !{!"m_OpenCLArgNamesVec[2]", !""} !233 = !{!"m_OpenCLArgNamesVec[3]", !""} !234 = !{!"m_OpenCLArgNamesVec[4]", !""} !235 = !{!"m_OpenCLArgNamesVec[5]", !""} !236 = !{!"m_OpenCLArgNamesVec[6]", !""} !237 = !{!"m_OpenCLArgNamesVec[7]", !""} !238 = !{!"m_OpenCLArgNamesVec[8]", !""} !239 = !{!"m_OpenCLArgNamesVec[9]", !""} !240 = !{!"m_OpenCLArgScalarAsPointers"} !241 = !{!"m_OptsToDisablePerFunc"} !242 = !{!"pushInfo", !243, !244, !245, !248, !249, !250, !251, !252, !253, !254, !255, !268, !269, !270, !271} !243 = !{!"pushableAddresses"} !244 = 
!{!"bindlessPushInfo"} !245 = !{!"dynamicBufferInfo", !246, !247} !246 = !{!"firstIndex", i32 0} !247 = !{!"numOffsets", i32 0} !248 = !{!"MaxNumberOfPushedBuffers", i32 0} !249 = !{!"inlineConstantBufferSlot", i32 -1} !250 = !{!"inlineConstantBufferOffset", i32 -1} !251 = !{!"inlineConstantBufferGRFOffset", i32 -1} !252 = !{!"constants"} !253 = !{!"inputs"} !254 = !{!"constantReg"} !255 = !{!"simplePushInfoArr", !256, !265, !266, !267} !256 = !{!"simplePushInfoArrVec[0]", !257, !258, !259, !260, !261, !262, !263, !264} !257 = !{!"cbIdx", i32 0} !258 = !{!"pushableAddressGrfOffset", i32 -1} !259 = !{!"pushableOffsetGrfOffset", i32 -1} !260 = !{!"offset", i32 0} !261 = !{!"size", i32 0} !262 = !{!"isStateless", i1 false} !263 = !{!"isBindless", i1 false} !264 = !{!"simplePushLoads"} !265 = !{!"simplePushInfoArrVec[1]", !257, !258, !259, !260, !261, !262, !263, !264} !266 = !{!"simplePushInfoArrVec[2]", !257, !258, !259, !260, !261, !262, !263, !264} !267 = !{!"simplePushInfoArrVec[3]", !257, !258, !259, !260, !261, !262, !263, !264} !268 = !{!"simplePushBufferUsed", i32 0} !269 = !{!"pushAnalysisWIInfos"} !270 = !{!"inlineRTGlobalPtrOffset", i32 0} !271 = !{!"rtSyncSurfPtrOffset", i32 0} !272 = !{!"psInfo", !273, !274, !275, !276, !277, !278, !279, !280, !281, !282, !283, !284, !285, !286, !287} !273 = !{!"BlendStateDisabledMask", i8 0} !274 = !{!"SkipSrc0Alpha", i1 false} !275 = !{!"DualSourceBlendingDisabled", i1 false} !276 = !{!"ForceEnableSimd32", i1 false} !277 = !{!"outputDepth", i1 false} !278 = !{!"outputStencil", i1 false} !279 = !{!"outputMask", i1 false} !280 = !{!"blendToFillEnabled", i1 false} !281 = !{!"forceEarlyZ", i1 false} !282 = !{!"hasVersionedLoop", i1 false} !283 = !{!"forceSingleSourceRTWAfterDualSourceRTW", i1 false} !284 = !{!"NumSamples", i8 0} !285 = !{!"blendOptimizationMode"} !286 = !{!"colorOutputMask"} !287 = !{!"WaDisableVRS", i1 false} !288 = !{!"csInfo", !289, !290, !291, !292, !293, !38, !39, !294, !295, !296, !297, !298, !299, 
!300, !301, !68, !302, !303, !304} !289 = !{!"maxWorkGroupSize", i32 0} !290 = !{!"waveSize", i32 0} !291 = !{!"ComputeShaderSecondCompile"} !292 = !{!"forcedSIMDSize", i8 0} !293 = !{!"forceTotalGRFNum", i32 0} !294 = !{!"allowLowerSimd", i1 false} !295 = !{!"disableSimd32Slicing", i1 false} !296 = !{!"disableSplitOnSpill", i1 false} !297 = !{!"forcedVISAPreRAScheduler", i1 false} !298 = !{!"disableLocalIdOrderOptimizations", i1 false} !299 = !{!"disableDispatchAlongY", i1 false} !300 = !{!"neededThreadIdLayout", i1* null} !301 = !{!"forceTileYWalk", i1 false} !302 = !{!"walkOrderEnabled", i1 false} !303 = !{!"walkOrderOverride", i32 0} !304 = !{!"ResForHfPacking"} !305 = !{!"msInfo", !306, !307, !308, !309, !310, !311, !312, !313, !314} !306 = !{!"PrimitiveTopology", i32 3} !307 = !{!"MaxNumOfPrimitives", i32 0} !308 = !{!"MaxNumOfVertices", i32 0} !309 = !{!"MaxNumOfPerPrimitiveOutputs", i32 0} !310 = !{!"MaxNumOfPerVertexOutputs", i32 0} !311 = !{!"WorkGroupSize", i32 0} !312 = !{!"WorkGroupMemorySizeInBytes", i32 0} !313 = !{!"IndexFormat", i32 6} !314 = !{!"SubgroupSize", i32 0} !315 = !{!"taskInfo", !316, !311, !312, !314} !316 = !{!"MaxNumOfOutputs", i32 0} !317 = !{!"NBarrierCnt", i32 0} !318 = !{!"rtInfo", !319, !320, !321, !322, !323, !324, !325, !326, !327, !328, !329, !330} !319 = !{!"RayQueryAllocSizeInBytes", i32 0} !320 = !{!"NumContinuations", i32 0} !321 = !{!"RTAsyncStackAddrspace", i32 -1} !322 = !{!"RTAsyncStackSurfaceStateOffset", i1* null} !323 = !{!"SWHotZoneAddrspace", i32 -1} !324 = !{!"SWHotZoneSurfaceStateOffset", i1* null} !325 = !{!"SWStackAddrspace", i32 -1} !326 = !{!"SWStackSurfaceStateOffset", i1* null} !327 = !{!"RTSyncStackAddrspace", i32 -1} !328 = !{!"RTSyncStackSurfaceStateOffset", i1* null} !329 = !{!"doSyncDispatchRays", i1 false} !330 = !{!"MemStyle", !"Xe"} !331 = !{!"CurUniqueIndirectIdx", i32 0} !332 = !{!"inlineDynTextures"} !333 = !{!"inlineResInfoData"} !334 = !{!"immConstant", !335, !336, !337} !335 = !{!"data"} !336 
= !{!"sizes"} !337 = !{!"zeroIdxs"} !338 = !{!"stringConstants"} !339 = !{!"inlineConstantBuffers"} !340 = !{!"inlineGlobalBuffers"} !341 = !{!"GlobalPointerProgramBinaryInfos"} !342 = !{!"ConstantPointerProgramBinaryInfos"} !343 = !{!"GlobalBufferAddressRelocInfo"} !344 = !{!"ConstantBufferAddressRelocInfo"} !345 = !{!"forceLscCacheList"} !346 = !{!"SrvMap"} !347 = !{!"RasterizerOrderedByteAddressBuffer"} !348 = !{!"RasterizerOrderedViews"} !349 = !{!"MinNOSPushConstantSize", i32 0} !350 = !{!"inlineProgramScopeOffsets"} !351 = !{!"shaderData", !352} !352 = !{!"numReplicas", i32 0} !353 = !{!"URBInfo", !354, !355, !356} !354 = !{!"has64BVertexHeaderInput", i1 false} !355 = !{!"has64BVertexHeaderOutput", i1 false} !356 = !{!"hasVertexHeader", i1 true} !357 = !{!"UseBindlessImage", i1 false} !358 = !{!"enableRangeReduce", i1 false} !359 = !{!"allowMatchMadOptimizationforVS", i1 false} !360 = !{!"disableMatchMadOptimizationForCS", i1 false} !361 = !{!"disableMemOptforNegativeOffsetLoads", i1 false} !362 = !{!"enableThreeWayLoadSpiltOpt", i1 false} !363 = !{!"statefulResourcesNotAliased", i1 false} !364 = !{!"disableMixMode", i1 false} !365 = !{!"genericAccessesResolved", i1 false} !366 = !{!"PrivateMemoryPerFG"} !367 = !{!"m_OptsToDisable"} !368 = !{!"capabilities", !369} !369 = !{!"globalVariableDecorationsINTEL", i1 false} !370 = !{!"m_ShaderResourceViewMcsMask", !371, !372} !371 = !{!"m_ShaderResourceViewMcsMaskVec[0]", i64 0} !372 = !{!"m_ShaderResourceViewMcsMaskVec[1]", i64 0} !373 = !{!"computedDepthMode", i32 0} !374 = !{!"isHDCFastClearShader", i1 false} !375 = !{i32 2, i32 0} !376 = !{!"clang version 14.0.5"} !377 = distinct !DISubprogram(name: "_kernel_0d1d2d3d4d5d6d7c8d9c10d11c", linkageName: "_kernel_0d1d2d3d4d5d6d7c8d9c10d11c", scope: null, file: !4, line: 88, type: !378, scopeLine: 88, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !3, templateParams: !380, retainedNodes: !380) !378 = !DISubroutineType(types: !379) !379 = !{null} !380 = !{} 
!381 = !DILocation(line: 88, scope: !377) !382 = !DILocation(line: 44, column: 22, scope: !383, inlinedAt: !385) !383 = !DILexicalBlockFile(scope: !377, file: !384, discriminator: 0) !384 = !DIFile(filename: "standard.py", directory: "/home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language") !385 = distinct !DILocation(line: 101, scope: !377) !386 = !DILocation(line: 44, column: 28, scope: !383, inlinedAt: !385) !387 = !DILocation(line: 44, column: 22, scope: !383, inlinedAt: !388) !388 = distinct !DILocation(line: 102, scope: !377) !389 = !DILocation(line: 44, column: 28, scope: !383, inlinedAt: !388) !390 = !DILocation(line: 104, column: 22, scope: !377) !391 = !DILocation(line: 105, column: 22, scope: !377) !392 = !DILocation(line: 106, column: 41, scope: !377) !393 = !DILocation(line: 106, column: 30, scope: !377) !394 = !DILocation(line: 106, column: 50, scope: !377) !395 = !DILocation(line: 107, column: 40, scope: !377) !396 = !DILocation(line: 107, column: 34, scope: !377) !397 = !DILocation(line: 108, column: 30, scope: !377) !398 = !DILocation(line: 110, column: 17, scope: !377) !399 = !DILocation(line: 110, column: 40, scope: !377) !400 = !DILocation(line: 110, column: 27, scope: !377) !401 = !DILocation(line: 111, column: 17, scope: !377) !402 = !DILocation(line: 111, column: 27, scope: !377) !403 = !DILocation(line: 44, column: 22, scope: !383, inlinedAt: !404) !404 = distinct !DILocation(line: 119, scope: !377) !405 = !DILocation(line: 44, column: 28, scope: !383, inlinedAt: !404) !406 = !DILocation(line: 119, column: 22, scope: !377) !407 = !DILocation(line: 136, column: 33, scope: !377) !408 = !DILocation(line: 114, column: 17, scope: !377) !409 = !DILocation(line: 114, column: 27, scope: !377) !410 = !DILocation(line: 117, column: 27, scope: !377) !411 = !DILocation(line: 113, column: 48, scope: !377) !412 = !DILocation(line: 117, column: 39, scope: !377) !413 = !DILocation(line: 117, column: 13, scope: !377) !414 = 
!DILocation(line: 112, column: 48, scope: !377) !415 = !DILocation(line: 116, column: 28, scope: !377) !416 = !DILocation(line: 116, column: 40, scope: !377) !417 = !DILocation(line: 116, column: 13, scope: !377) !418 = !DILocation(line: 132, column: 31, scope: !377) !419 = !DILocation(line: 121, column: 24, scope: !377) !420 = !DILocation(line: 122, column: 24, scope: !377) !421 = !DILocation(line: 135, column: 13, scope: !377) !422 = !DILocation(line: 136, column: 13, scope: !377) !423 = !DILocation(line: 137, column: 17, scope: !377) !424 = !DILocation(line: 145, column: 20, scope: !377) !425 = !DILocation(line: 142, column: 17, scope: !377) !426 = !DILocation(line: 142, column: 37, scope: !377) !427 = !DILocation(line: 142, column: 31, scope: !377) !428 = !DILocation(line: 141, column: 27, scope: !377) !429 = !DILocation(line: 141, column: 39, scope: !377) !430 = !DILocation(line: 141, column: 13, scope: !377) !431 = !DILocation(line: 144, column: 4, scope: !377) ```
Gen ISA assemble ``` //.kernel _kernel_0d1d2d3d4d5d6d7c8d9c10d11c //.platform PVCXT //.thread_config numGRF=128, numAcc=4, numSWSB=16 //.options_string "-emitCrossThreadOffR0Reloc " //.full_options "-emitLocation -enableCoalesceScalarMoves -hasRNEandDenorm -noStitchExternFunc -emitCrossThreadOffR0Reloc -linker 63 -preserver0 -abortOnSpill 4 -enableBundleCR 3 -boundsChecking -presched-rp 100 -nodpsendreorder -SBIDDepLoc -output -binary -dumpcommonisa -dumpcombinedcisa -dumpvisa -printHexFloatInAsm -noverifyCISA -enableHalfLSC -partialInt64 -generateDebugInfo " //.instCount 525 //.RA type GRAPH_COLORING_FF_BC_RA //.git-hash 80055716ef5675b711e3f3198c61b0ff7dc7208d //.declare BuiltInR0 (0) rf=r size=64 type=ud align=32 words (r0.0) IsBuiltin //.declare (1) rf=r size=64 type=ud alias=BuiltInR0+0 align=32 words (r0.0) IsBuiltin //.declare BuiltinA0 (2) rf=a size=4 type=ud align=1 words (a0.0) IsBuiltin //.declare BuiltinA0Dot2 (3) rf=a size=4 type=ud align=1 words (a0.2) IsBuiltin //.declare %null (9) rf=r size=4 type=ud align=32 words //.declare %local_id_x (12) rf=r size=4 type=ud align=2 words (r1.12) //.declare %local_id_y (13) rf=r size=4 type=ud align=2 words (r1.13) //.declare %local_size_x (14) rf=r size=4 type=ud align=2 words (r1.8) //.declare %local_size_y (15) rf=r size=4 type=ud align=2 words (r1.9) //.declare %group_id_x (16) rf=r size=4 type=ud align=2 words (r0.1) //.declare %group_id_y (17) rf=r size=4 type=ud align=2 words (r0.6) //.declare %group_id_z (18) rf=r size=4 type=ud align=2 words (r0.7) //.declare %group_count_x (19) rf=r size=4 type=ud align=2 words (r1.10) //.declare %group_count_y (20) rf=r size=4 type=ud align=2 words (r1.11) //.declare %tsc (21) rf=r size=20 type=ud align=2 words //.declare %arg (22) rf=r size=0 type=ud align=32 words (r26.0) //.declare %retval (23) rf=r size=0 type=ud align=32 words (r26.0) Output //.declare %sp (24) rf=r size=8 type=uq align=32 words (r125.3) //.declare %fp (25) rf=r size=8 type=uq align=32 words 
(r125.2) //.declare %sr0 (26) rf=r size=16 type=ud align=2 words //.declare %cr0 (27) rf=r size=12 type=ud align=2 words //.declare %ce0 (28) rf=r size=4 type=ud align=2 words //.declare %dbg0 (29) rf=r size=8 type=ud align=2 words //.declare implBufPtr (31) rf=r size=8 type=uq align=32 words (r126.0) //.declare localIdBufPtr (32) rf=r size=8 type=uq align=32 words (r126.3) //.declare %msg0 (33) rf=r size=12 type=ud align=2 words //.declare V0033 (41) rf=r size=64 type=d alias=+0 align=32 words (r0.0) //.declare V0034 (42) rf=r size=8 type=uq align=4 words (r4.5) //.declare V0035 (43) rf=r size=8 type=uq align=4 words (r4.6) //.declare V0036 (44) rf=r size=8 type=uq align=4 words (r4.7) //.declare V0037 (45) rf=r size=4 type=d align=2 words (r5.0) //.declare V0038 (46) rf=r size=4 type=d align=2 words (r5.1) //.declare V0039 (47) rf=r size=4 type=d align=2 words (r5.2) //.declare V0040 (48) rf=r size=4 type=d align=2 words (r5.3) //.declare V0041 (49) rf=r size=4 type=d align=2 words (r5.4) //.declare V0042 (50) rf=r size=4 type=d align=2 words (r5.5) //.declare V0044 (52) rf=r size=32 type=d alias=+0 align=32 words (r0.0) //.declare V0045 (53) rf=r size=32 type=d align=16 words (r4.0) //.declare V0046 (54) rf=r size=64 type=w align=32 words (r1.0) //.declare V0047 (55) rf=r size=64 type=w align=32 words (r2.0) //.declare V0048 (56) rf=r size=64 type=w align=32 words (r3.0) //.declare V0049 (57) rf=r size=8 type=uq align=4 words (r5.3) //.declare V0053 (61) rf=r size=4 type=d align=2 words (r5.6) //.declare P01 (62) rf=f32 size=4 type=uw align=2 words (f2.0) //.declare V0054 (63) rf=r size=4 type=d align=2 words (r2.9) //.declare V0055 (64) rf=r size=4 type=d align=2 words (r5.7) //.declare P02 (65) rf=f32 size=4 type=uw align=2 words (f1.0) //.declare V0056 (66) rf=r size=4 type=d align=2 words (r5.6) //.declare V0057 (67) rf=r size=4 type=d align=2 words (r2.7) //.declare P03 (68) rf=f32 size=4 type=uw align=2 words (f0.0) //.declare V0058 (69) rf=r size=4 type=d 
align=2 words (r2.6) //.declare V0059 (70) rf=r size=4 type=d align=2 words (r2.3) //.declare V0060 (71) rf=r size=4 type=d align=2 words (r2.2) //.declare V0061 (72) rf=r size=4 type=d align=2 words (r5.6) //.declare V0062 (73) rf=r size=4 type=d align=2 words (r2.0) //.declare V0063 (74) rf=r size=4 type=d align=2 words (r5.6) //.declare V0064 (75) rf=r size=4 type=d align=2 words (r2.8) //.declare V0065 (76) rf=r size=4 type=f align=2 words (r5.6) //.declare V0066 (77) rf=r size=4 type=ud alias=V0062+0 align=2 words (r2.0) //.declare V0067 (78) rf=r size=4 type=f align=2 words (r2.4) //.declare V0068 (79) rf=r size=8 type=df align=4 words (r5.3) //.declare V0069 (80) rf=r size=8 type=df align=4 words (r2.0) //.declare V0070 (81) rf=r size=8 type=df align=4 words (r5.4) //.declare V0071 (82) rf=r size=8 type=df align=4 words (r2.2) //.declare V0072 (83) rf=r size=8 type=df align=4 words (r2.0) //.declare V0073 (84) rf=r size=4 type=ud alias=V0064+0 align=2 words (r2.8) //.declare V0074 (85) rf=r size=8 type=df align=4 words (r2.0) //.declare V0075 (86) rf=r size=8 type=df align=4 words (r2.0) //.declare V0076 (87) rf=r size=4 type=d align=4 words (r2.0) //.declare V0077 (88) rf=r size=4 type=ud alias=V0076+0 align=2 words (r2.0) //.declare V0078 (89) rf=r size=4 type=d align=2 words (r5.6) //.declare V0079 (90) rf=r size=4 type=d align=2 words (r5.6) //.declare V0080 (91) rf=r size=4 type=d align=2 words (r2.13) //.declare V0081 (92) rf=r size=4 type=d align=2 words (r5.6) //.declare V0082 (93) rf=r size=4 type=d alias=+0 align=2 words (r2.4) //.declare P04 (94) rf=f32 size=4 type=uw align=2 words (f2.0) //.declare V0083 (95) rf=r size=4 type=d align=2 words (r2.12) //.declare V0084 (96) rf=r size=4 type=d align=2 words (r2.2) //.declare V0085 (97) rf=r size=4 type=d align=2 words (r2.0) //.declare V0086 (98) rf=r size=4 type=d align=2 words (r5.6) //.declare V0087 (99) rf=r size=4 type=d align=2 words (r2.3) //.declare V0088 (100) rf=r size=4 type=f align=2 
words (r5.6) //.declare V0089 (101) rf=r size=4 type=ud alias=V0085+0 align=2 words (r2.0) //.declare V0090 (102) rf=r size=4 type=f align=2 words (r2.1) //.declare V0091 (103) rf=r size=8 type=df align=4 words (r5.3) //.declare V0092 (104) rf=r size=8 type=df align=4 words (r2.4) //.declare V0093 (105) rf=r size=8 type=df align=4 words (r2.0) //.declare V0094 (106) rf=r size=8 type=df align=4 words (r2.5) //.declare V0095 (107) rf=r size=8 type=df align=4 words (r2.4) //.declare V0096 (108) rf=r size=4 type=ud alias=V0087+0 align=2 words (r2.3) //.declare V0097 (109) rf=r size=8 type=df align=4 words (r2.0) //.declare V0098 (110) rf=r size=8 type=df align=4 words (r2.0) //.declare V0099 (111) rf=r size=4 type=d align=4 words (r2.8) //.declare V0100 (112) rf=r size=4 type=ud alias=V0099+0 align=2 words (r2.8) //.declare V0101 (113) rf=r size=4 type=d align=32 words (r2.0) //.declare V0102 (114) rf=r size=4 type=d align=2 words (r5.6) //.declare V0103 (115) rf=r size=4 type=d align=2 words (r2.9) //.declare V0104 (116) rf=r size=4 type=d align=2 words (r5.8) //.declare V0105 (117) rf=r size=4 type=d align=32 words (r3.0) //.declare V0106 (118) rf=r size=4 type=d alias=+4 align=2 words (r2.5) //.declare V0107 (119) rf=r size=4 type=d alias=+0 align=2 words (r2.0) //.declare V0108 (120) rf=r size=4 type=d alias=+4 align=2 words (r2.1) //.declare V0109 (121) rf=r size=4 type=d align=2 words (r5.6) //.declare V0110 (122) rf=r size=4 type=d align=2 words (r2.2) //.declare V0111 (123) rf=r size=4 type=d align=2 words (r5.6) //.declare V0112 (124) rf=r size=4 type=d align=2 words (r2.8) //.declare V0113 (125) rf=r size=4 type=f align=2 words (r5.6) //.declare V0114 (126) rf=r size=4 type=ud alias=V0110+0 align=2 words (r2.2) //.declare V0115 (127) rf=r size=4 type=f align=2 words (r2.3) //.declare V0116 (128) rf=r size=8 type=df align=4 words (r5.3) //.declare V0117 (129) rf=r size=8 type=df align=4 words (r2.2) //.declare V0118 (130) rf=r size=8 type=df align=4 words 
(r2.1) //.declare V0119 (131) rf=r size=8 type=df align=4 words (r2.3) //.declare V0120 (132) rf=r size=8 type=df align=4 words (r2.2) //.declare V0121 (133) rf=r size=4 type=ud alias=V0112+0 align=2 words (r2.8) //.declare V0122 (134) rf=r size=8 type=df align=4 words (r2.1) //.declare V0123 (135) rf=r size=8 type=df align=4 words (r2.1) //.declare V0124 (136) rf=r size=4 type=d align=4 words (r2.2) //.declare V0125 (137) rf=r size=4 type=ud alias=V0124+0 align=2 words (r2.2) //.declare V0126 (138) rf=r size=4 type=d align=2 words (r5.6) //.declare V0127 (139) rf=r size=4 type=d align=2 words (r5.6) //.declare V0128 (140) rf=r size=4 type=d align=2 words (r5.6) //.declare V0129 (141) rf=r size=4 type=d align=2 words (r5.13) //.declare V0130 (142) rf=r size=64 type=w align=32 words (r8.0) //.declare V0131 (143) rf=r size=64 type=uw alias=V0046+0 align=32 words (r1.0) //.declare V0132 (144) rf=r size=128 type=d align=32 words (r6.0) //.declare V0133 (145) rf=r size=128 type=d align=32 words (r62.0) //.declare V0134 (146) rf=r size=128 type=d align=32 words (r2.0) //.declare V0135 (147) rf=r size=128 type=d align=32 words (r12.0) //.declare V0136 (148) rf=r size=64 type=uw alias=V0130+0 align=32 words (r8.0) //.declare V0137 (149) rf=r size=128 type=d align=32 words (r64.0) //.declare V0138 (150) rf=r size=128 type=d align=32 words (r116.0) //.declare V0139 (151) rf=r size=4 type=d align=2 words (r8.0) //.declare V0140 (152) rf=r size=4 type=d align=2 words (r5.7) //.declare P05 (153) rf=f32 size=4 type=uw align=2 words (f3.0) //.declare V0141 (154) rf=r size=4 type=d align=2 words (r5.6) //.declare P06 (157) rf=f32 size=4 type=uw align=2 words (f2.0) //.declare V0144 (158) rf=r size=64 type=hf align=32 words (r54.0) //.declare V0145 (159) rf=r size=64 type=hf align=32 words (r55.0) //.declare V0146 (160) rf=r size=64 type=hf align=32 words (r56.0) //.declare V0147 (161) rf=r size=64 type=hf align=32 words (r57.0) //.declare V0148 (162) rf=r size=64 type=hf align=32 
words (r58.0) //.declare V0149 (163) rf=r size=64 type=hf align=32 words (r59.0) //.declare V0150 (164) rf=r size=64 type=hf align=32 words (r60.0) //.declare V0151 (165) rf=r size=64 type=hf align=32 words (r61.0) //.declare V0152 (166) rf=r size=4 type=d align=2 words (r5.2) //.declare V0153 (167) rf=r size=4 type=d align=2 words (r5.7) //.declare V0154 (168) rf=r size=128 type=d align=32 words (r10.0) //.declare V0155 (169) rf=r size=128 type=d align=32 words (r8.0) //.declare P07 (170) rf=f32 size=4 type=uw align=2 words (f1.0) //.declare V0156 (171) rf=r size=128 type=d align=32 words (r10.0) //.declare V0157 (172) rf=r size=4 type=d align=2 words (r5.12) //.declare V0158 (173) rf=r size=4 type=d align=2 words (r5.8) //.declare V0159 (174) rf=r size=128 type=d align=32 words (r12.0) //.declare V0160 (175) rf=r size=128 type=d align=32 words (r10.0) //.declare V0161 (176) rf=r size=4 type=f align=2 words (r5.4) //.declare V0162 (177) rf=r size=4 type=ud alias=V0158+0 align=2 words (r5.8) //.declare V0163 (178) rf=r size=4 type=f align=2 words (r5.4) //.declare V0164 (179) rf=r size=8 type=df align=4 words (r5.4) //.declare V0165 (180) rf=r size=8 type=df align=4 words (r12.0) //.declare V0166 (181) rf=r size=8 type=df align=4 words (r5.5) //.declare V0167 (182) rf=r size=8 type=df align=4 words (r22.0) //.declare V0169 (184) rf=r size=128 type=ud alias=V0160+0 align=32 words (r10.0) //.declare V0172 (187) rf=r size=128 type=d align=32 words (r14.0) //.declare V0173 (188) rf=r size=128 type=ud alias=V0172+0 align=32 words (r14.0) //.declare V0174 (189) rf=r size=128 type=d align=32 words (r16.0) //.declare V0175 (190) rf=r size=128 type=d align=32 words (r12.0) //.declare V0176 (191) rf=r size=128 type=d align=32 words (r12.0) //.declare V0177 (192) rf=r size=8 type=q alias=V0035+0 align=32 words (r4.6) //.declare V0180 (195) rf=r size=256 type=q align=32 words (r22.0) //.declare P08 (196) rf=f32 size=4 type=uw align=2 words (f0.0) //.declare V0181 (197) rf=r 
size=128 type=d align=32 words (r8.0) //.declare V0182 (198) rf=r size=4 type=d align=2 words (r5.12) //.declare V0183 (199) rf=r size=4 type=d align=2 words (r5.8) //.declare V0184 (200) rf=r size=128 type=d align=32 words (r10.0) //.declare V0185 (201) rf=r size=128 type=d align=32 words (r8.0) //.declare V0186 (202) rf=r size=4 type=f align=2 words (r5.4) //.declare V0187 (203) rf=r size=4 type=ud alias=V0183+0 align=2 words (r5.8) //.declare V0188 (204) rf=r size=4 type=f align=2 words (r5.4) //.declare V0189 (205) rf=r size=8 type=df align=4 words (r5.4) //.declare V0190 (206) rf=r size=8 type=df align=4 words (r10.0) //.declare V0191 (207) rf=r size=8 type=df align=4 words (r5.5) //.declare V0192 (208) rf=r size=8 type=df align=4 words (r20.0) //.declare V0194 (210) rf=r size=128 type=ud alias=V0185+0 align=32 words (r8.0) //.declare V0197 (213) rf=r size=128 type=d align=32 words (r12.0) //.declare V0198 (214) rf=r size=128 type=ud alias=V0197+0 align=32 words (r12.0) //.declare V0199 (215) rf=r size=128 type=d align=32 words (r14.0) //.declare V0200 (216) rf=r size=128 type=d align=32 words (r10.0) //.declare V0201 (217) rf=r size=4 type=d align=2 words (r5.6) //.declare V0202 (218) rf=r size=128 type=d align=32 words (r10.0) //.declare V0203 (219) rf=r size=128 type=d align=32 words (r8.0) //.declare V0204 (220) rf=r size=8 type=q alias=V0034+0 align=32 words (r4.5) //.declare V0207 (223) rf=r size=256 type=q align=32 words (r34.0) //.declare V0209 (225) rf=r size=128 type=d align=32 words (r8.0) //.declare V0210 (226) rf=r size=128 type=d align=32 words (r2.0) //.declare V0211 (227) rf=r size=128 type=d align=32 words (r110.0) //.declare V0212 (228) rf=r size=128 type=d align=32 words (r108.0) //.declare V0213 (229) rf=r size=128 type=d align=32 words (r106.0) //.declare V0214 (230) rf=r size=64 type=w align=32 words (r1.0) //.declare V0215 (231) rf=r size=64 type=w align=32 words (r10.0) //.declare V0217 (233) rf=r size=64 type=uw alias=V0214+0 align=32 
words (r1.0) //.declare V0218 (234) rf=r size=128 type=d align=32 words (r8.0) //.declare V0219 (235) rf=r size=128 type=d align=32 words (r104.0) //.declare V0220 (236) rf=r size=128 type=d align=32 words (r102.0) //.declare V0221 (237) rf=r size=128 type=d align=32 words (r100.0) //.declare V0222 (238) rf=r size=128 type=ud align=32 words (r98.0) //.declare V0223 (239) rf=r size=64 type=uw alias=V0215+0 align=32 words (r10.0) //.declare V0224 (240) rf=r size=128 type=ud alias=V0219+0 align=32 words (r104.0) //.declare V0225 (241) rf=r size=128 type=ud alias=V0220+0 align=32 words (r102.0) //.declare V0226 (242) rf=r size=128 type=ud alias=V0221+0 align=32 words (r100.0) //.declare V0227 (243) rf=r size=64 type=w align=32 words (r1.0) //.declare V0228 (244) rf=r size=128 type=ud align=32 words (r96.0) //.declare V0229 (245) rf=r size=64 type=uw alias=V0227+0 align=32 words (r1.0) //.declare V0230 (246) rf=r size=128 type=d align=32 words (r94.0) //.declare V0231 (247) rf=r size=128 type=ud alias=V0230+0 align=32 words (r94.0) //.declare V0232 (248) rf=r size=128 type=d align=32 words (r92.0) //.declare V0233 (249) rf=r size=128 type=ud alias=V0232+0 align=32 words (r92.0) //.declare V0234 (250) rf=r size=128 type=d align=32 words (r90.0) //.declare V0235 (251) rf=r size=128 type=ud alias=V0234+0 align=32 words (r90.0) //.declare V0236 (252) rf=r size=128 type=d align=32 words (r12.0) //.declare V0237 (253) rf=r size=128 type=d align=32 words (r10.0) //.declare V0238 (254) rf=r size=128 type=d align=32 words (r8.0) //.declare V0239 (255) rf=r size=4 type=d align=2 words (r5.3) //.declare V0240 (256) rf=r size=128 type=d align=32 words (r6.0) //.declare V0241 (257) rf=r size=128 type=d align=32 words (r88.0) //.declare V0242 (258) rf=r size=128 type=ud alias=V0241+0 align=32 words (r88.0) //.declare V0243 (259) rf=r size=128 type=d align=32 words (r86.0) //.declare V0244 (260) rf=r size=128 type=ud alias=V0243+0 align=32 words (r86.0) //.declare V0245 (261) rf=r 
size=128 type=d align=32 words (r84.0) //.declare V0246 (262) rf=r size=128 type=ud alias=V0245+0 align=32 words (r84.0) //.declare V0247 (263) rf=r size=128 type=d align=32 words (r82.0) //.declare V0248 (264) rf=r size=128 type=ud alias=V0247+0 align=32 words (r82.0) //.declare V0249 (265) rf=r size=128 type=d align=32 words (r80.0) //.declare V0250 (266) rf=r size=128 type=ud alias=V0249+0 align=32 words (r80.0) //.declare V0251 (267) rf=r size=128 type=d align=32 words (r78.0) //.declare V0252 (268) rf=r size=128 type=ud alias=V0251+0 align=32 words (r78.0) //.declare V0253 (269) rf=r size=128 type=d align=32 words (r76.0) //.declare V0254 (270) rf=r size=128 type=ud alias=V0253+0 align=32 words (r76.0) //.declare V0255 (271) rf=r size=128 type=d align=32 words (r74.0) //.declare V0256 (272) rf=r size=128 type=ud alias=V0255+0 align=32 words (r74.0) //.declare V0257 (273) rf=r size=128 type=d align=32 words (r72.0) //.declare V0258 (274) rf=r size=128 type=d align=32 words (r70.0) //.declare V0259 (275) rf=r size=128 type=d align=32 words (r68.0) //.declare V0260 (276) rf=r size=128 type=d align=32 words (r66.0) //.declare V0262 (278) rf=r size=8 type=q align=4 words (r5.1) //.declare V0263 (279) rf=r size=128 type=f align=32 words (r52.0) //.declare V0264 (280) rf=r size=128 type=f align=32 words (r50.0) //.declare V0265 (281) rf=r size=128 type=f align=32 words (r48.0) //.declare V0266 (282) rf=r size=128 type=f align=32 words (r46.0) //.declare V0267 (283) rf=r size=128 type=f align=32 words (r44.0) //.declare V0268 (284) rf=r size=128 type=f align=32 words (r42.0) //.declare V0269 (285) rf=r size=128 type=f align=32 words (r40.0) //.declare V0270 (286) rf=r size=128 type=f align=32 words (r38.0) //.declare V0271 (287) rf=r size=4 type=d align=2 words (r5.4) //.declare V0272 (288) rf=r size=256 type=uq alias=V0207+0 align=32 words (r34.0) //.declare V0273 (289) rf=r size=512 type=d align=32 words (r6.0) //.declare V0276 (292) rf=r size=512 type=ud 
alias=V0273+0 align=32 words (r6.0) //.declare (303) rf=r size=64 type=ud align=32 words (r1.0) //.declare (304) rf=r size=32 type=ud align=32 words (r1.0) //.declare V0289 (307) rf=r size=128 type=ud alias=V0210+0 align=32 words (r2.0) //.declare V0290 (308) rf=r size=128 type=d align=32 words (r26.0) //.declare V0291 (309) rf=r size=128 type=w alias=V0290+0 align=32 words (r26.0) //.declare V0292 (310) rf=r size=128 type=ud alias=V0257+0 align=32 words (r72.0) //.declare V0293 (311) rf=r size=128 type=d align=32 words (r20.0) //.declare V0294 (312) rf=r size=128 type=w alias=V0293+0 align=32 words (r20.0) //.declare V0297 (315) rf=r size=128 type=ud alias=V0211+0 align=32 words (r110.0) //.declare V0298 (316) rf=r size=128 type=d align=32 words (r18.0) //.declare V0299 (317) rf=r size=128 type=w alias=V0298+0 align=32 words (r18.0) //.declare V0300 (318) rf=r size=128 type=ud alias=V0258+0 align=32 words (r70.0) //.declare V0301 (319) rf=r size=128 type=d align=32 words (r16.0) //.declare V0302 (320) rf=r size=128 type=w alias=V0301+0 align=32 words (r16.0) //.declare V0303 (321) rf=r size=256 type=uq alias=V0180+0 align=32 words (r22.0) //.declare V0304 (322) rf=r size=512 type=d align=32 words (r6.0) //.declare V0307 (325) rf=r size=512 type=ud alias=V0304+0 align=32 words (r6.0) //.declare V0320 (338) rf=r size=128 type=ud alias=V0212+0 align=32 words (r108.0) //.declare V0321 (339) rf=r size=128 type=d align=32 words (r26.0) //.declare V0322 (340) rf=r size=128 type=w alias=V0321+0 align=32 words (r26.0) //.declare V0323 (341) rf=r size=128 type=ud alias=V0259+0 align=32 words (r68.0) //.declare V0324 (342) rf=r size=128 type=d align=32 words (r20.0) //.declare V0325 (343) rf=r size=128 type=w alias=V0324+0 align=32 words (r20.0) //.declare V0328 (346) rf=r size=128 type=ud alias=V0213+0 align=32 words (r106.0) //.declare V0329 (347) rf=r size=128 type=d align=32 words (r18.0) //.declare V0330 (348) rf=r size=128 type=w alias=V0329+0 align=32 words (r18.0) 
//.declare V0331 (349) rf=r size=128 type=ud alias=V0260+0 align=32 words (r66.0) //.declare V0332 (350) rf=r size=128 type=d align=32 words (r16.0) //.declare V0333 (351) rf=r size=128 type=w alias=V0332+0 align=32 words (r16.0) //.declare (352) rf=r size=64 type=ud align=32 words (r1.0) //.declare (353) rf=r size=32 type=ud align=32 words (r1.0) //.declare V0335 (355) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0336 (356) rf=r size=128 type=w alias=V0335+0 align=32 words (r6.0) //.declare V0338 (358) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0339 (359) rf=r size=128 type=w alias=V0338+0 align=32 words (r6.0) //.declare V0341 (361) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0342 (362) rf=r size=128 type=w alias=V0341+0 align=32 words (r6.0) //.declare V0344 (364) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0345 (365) rf=r size=128 type=w alias=V0344+0 align=32 words (r6.0) //.declare V0347 (367) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0348 (368) rf=r size=128 type=w alias=V0347+0 align=32 words (r6.0) //.declare V0350 (370) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0351 (371) rf=r size=128 type=w alias=V0350+0 align=32 words (r6.0) //.declare V0353 (373) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0354 (374) rf=r size=128 type=w alias=V0353+0 align=32 words (r6.0) //.declare V0356 (376) rf=r size=128 type=ud align=32 words (r6.0) //.declare V0357 (377) rf=r size=128 type=w alias=V0356+0 align=32 words (r6.0) //.declare V0359 (379) rf=r size=128 type=ud align=32 words (r14.0) //.declare V0360 (380) rf=r size=128 type=hf alias=V0359+0 align=32 words (r14.0) //.declare V0362 (382) rf=r size=128 type=ud align=32 words (r14.0) //.declare V0363 (383) rf=r size=128 type=hf alias=V0362+0 align=32 words (r14.0) //.declare V0365 (385) rf=r size=128 type=ud align=32 words (r14.0) //.declare V0366 (386) rf=r size=128 type=hf alias=V0365+0 align=32 words (r14.0) //.declare V0368 (388) 
rf=r size=128 type=ud align=32 words (r14.0) //.declare V0369 (389) rf=r size=128 type=hf alias=V0368+0 align=32 words (r14.0) //.declare V0371 (391) rf=r size=128 type=ud align=32 words (r14.0) //.declare V0372 (392) rf=r size=128 type=hf alias=V0371+0 align=32 words (r14.0) //.declare V0374 (394) rf=r size=128 type=ud align=32 words (r14.0) //.declare V0375 (395) rf=r size=128 type=hf alias=V0374+0 align=32 words (r14.0) //.declare V0377 (397) rf=r size=128 type=ud align=32 words (r14.0) //.declare V0378 (398) rf=r size=128 type=hf alias=V0377+0 align=32 words (r14.0) //.declare V0380 (400) rf=r size=128 type=ud align=32 words (r14.0) //.declare V0381 (401) rf=r size=128 type=hf alias=V0380+0 align=32 words (r14.0) //.declare V0382 (402) rf=r size=256 type=w align=32 words (r58.0) //.declare V0383 (403) rf=r size=256 type=w align=32 words (r54.0) //.declare V0385 (405) rf=r size=512 type=d align=32 words (r6.0) //.declare V0386 (406) rf=r size=512 type=hf alias=V0385+0 align=32 words (r6.0) //.declare V0387 (407) rf=r size=512 type=f align=32 words (r26.0) //.declare V0388 (408) rf=r size=512 type=f align=32 words (r14.0) //.declare V0389 (409) rf=r size=512 type=f align=32 words (r26.0) //.declare V0390 (410) rf=r size=256 type=ud alias=V0382+0 align=32 words (r58.0) //.declare V0391 (411) rf=r size=512 type=f align=32 words (r14.0) //.declare V0392 (412) rf=r size=256 type=ud alias=V0383+0 align=32 words (r54.0) //.declare P09 (413) rf=f32 size=4 type=uw align=2 words (f3.0) //.declare V0395 (416) rf=r size=128 type=d align=32 words (r2.0) //.declare P10 (417) rf=f32 size=4 type=uw align=2 words (f3.0) //.declare P11 (418) rf=f32 size=4 type=uw align=2 words (f1.0) //.declare (419) rf=r size=64 type=ud align=32 words (r1.0) //.declare (420) rf=r size=32 type=ud align=32 words (r1.0) //.declare P12 (421) rf=f32 size=4 type=uw align=2 words (f0.0) //.declare V0397 (423) rf=r size=64 type=w align=32 words (r1.0) //.declare V0398 (424) rf=r size=128 type=d align=32 
words (r2.0) //.declare V0399 (425) rf=r size=128 type=d align=32 words (r8.0) //.declare V0400 (426) rf=r size=128 type=d align=32 words (r6.0) //.declare V0401 (427) rf=r size=64 type=b align=32 words (r1.0) //.declare V0402 (428) rf=r size=128 type=d align=32 words (r8.0) //.declare V0403 (429) rf=r size=64 type=ub alias=V0401+0 align=32 words (r1.0) //.declare V0404 (430) rf=r size=128 type=d align=32 words (r10.0) //.declare V0406 (432) rf=r size=128 type=ud alias=V0404+0 align=32 words (r10.0) //.declare V0408 (434) rf=r size=128 type=ud align=32 words (r8.0) //.declare V0409 (435) rf=r size=64 type=w align=32 words (r1.0) //.declare V0410 (436) rf=r size=64 type=b align=32 words (r1.0) //.declare V0411 (437) rf=r size=128 type=d align=32 words (r8.0) //.declare V0412 (438) rf=r size=64 type=ub alias=V0410+0 align=32 words (r1.0) //.declare V0413 (439) rf=r size=128 type=d align=32 words (r10.0) //.declare V0415 (441) rf=r size=128 type=ud alias=V0413+0 align=32 words (r10.0) //.declare V0417 (443) rf=r size=128 type=ud align=32 words (r8.0) //.declare V0418 (444) rf=r size=64 type=w align=32 words (r1.0) //.declare V0419 (445) rf=r size=64 type=b align=32 words (r1.0) //.declare V0420 (446) rf=r size=128 type=d align=32 words (r8.0) //.declare V0421 (447) rf=r size=64 type=ub alias=V0419+0 align=32 words (r1.0) //.declare V0422 (448) rf=r size=128 type=d align=32 words (r10.0) //.declare V0424 (450) rf=r size=128 type=ud alias=V0422+0 align=32 words (r10.0) //.declare V0426 (452) rf=r size=128 type=ud align=32 words (r8.0) //.declare V0427 (453) rf=r size=128 type=d align=32 words (r8.0) //.declare V0428 (454) rf=r size=128 type=d align=32 words (r10.0) //.declare V0430 (456) rf=r size=128 type=ud alias=V0428+0 align=32 words (r10.0) //.declare V0432 (458) rf=r size=128 type=ud align=32 words (r8.0) //.declare V0433 (459) rf=r size=128 type=d align=32 words (r8.0) //.declare V0434 (460) rf=r size=128 type=d align=32 words (r10.0) //.declare V0436 (462) rf=r 
size=128 type=ud alias=V0434+0 align=32 words (r10.0) //.declare V0438 (464) rf=r size=128 type=ud align=32 words (r8.0) //.declare V0439 (465) rf=r size=128 type=d align=32 words (r8.0) //.declare V0440 (466) rf=r size=128 type=d align=32 words (r10.0) //.declare V0442 (468) rf=r size=128 type=ud alias=V0440+0 align=32 words (r10.0) //.declare V0444 (470) rf=r size=128 type=ud align=32 words (r8.0) //.declare V0445 (471) rf=r size=128 type=d align=32 words (r8.0) //.declare V0446 (472) rf=r size=128 type=d align=32 words (r10.0) //.declare V0448 (474) rf=r size=128 type=ud alias=V0446+0 align=32 words (r10.0) //.declare V0450 (476) rf=r size=128 type=ud align=32 words (r8.0) //.declare V0451 (477) rf=r size=128 type=d align=32 words (r8.0) //.declare V0452 (478) rf=r size=128 type=d align=32 words (r10.0) //.declare V0454 (480) rf=r size=128 type=ud alias=V0452+0 align=32 words (r10.0) //.declare V0456 (482) rf=r size=128 type=ud align=32 words (r6.0) //.declare (483) rf=r size=64 type=ud align=32 words (r1.0) //.declare (484) rf=r size=32 type=ud align=32 words (r1.0) //.declare V0457 (485) rf=r size=128 type=d align=32 words (r6.0) //.declare V0458 (486) rf=r size=128 type=ud alias=V0398+0 align=32 words (r2.0) //.declare V0459 (487) rf=r size=128 type=d align=32 words (r8.0) //.declare V0460 (488) rf=r size=128 type=d align=32 words (r2.0) //.declare V0461 (489) rf=r size=128 type=d align=32 words (r6.0) //.declare V0462 (490) rf=r size=128 type=d align=32 words (r2.0) //.declare V0463 (491) rf=r size=128 type=ud alias=V0462+0 align=32 words (r2.0) //.declare V0464 (492) rf=r size=512 type=d align=32 words (r9.0) //.declare V0465 (493) rf=r size=128 type=d align=32 words (r2.0) //.declare V0466 (494) rf=r size=128 type=d align=32 words (r6.0) //.declare V0467 (495) rf=r size=8 type=q alias=V0036+0 align=32 words (r4.7) //.declare V0470 (498) rf=r size=256 type=q align=32 words (r5.0) //.declare V0471 (499) rf=r size=256 type=uq alias=V0470+0 align=32 words 
(r5.0) //.declare V0472 (500) rf=r size=4 type=ud align=2 words (r4.8) //.declare (501) rf=r size=64 type=ud align=32 words (r112.0) //.declare (502) rf=r size=8 type=d align=8 words (r2.0) //.declare (503) rf=r size=8 type=d align=8 words (r2.4) //.declare (504) rf=r size=8 type=df align=4 words (r5.3) //.declare (505) rf=r size=8 type=df align=4 words (r5.3) //.declare (506) rf=r size=4 type=d align=2 words (r2.0) //.declare (507) rf=r size=8 type=df align=4 words (r5.3) //.declare (508) rf=r size=128 type=uw align=32 words (r8.0) //.declare (509) rf=r size=128 type=uw align=32 words (r8.0) //.declare (510) rf=r size=8 type=df align=4 words (r5.4) //.declare (511) rf=r size=128 type=ud align=32 words (r12.0) //.declare (512) rf=r size=128 type=ud align=32 words (r12.0) //.declare (513) rf=r size=4 type=d align=2 words (r5.4) //.declare (515) rf=r size=128 type=ud align=32 words (r12.0) //.declare (516) rf=r size=128 type=ud align=32 words (r12.0) //.declare (517) rf=r size=128 type=ud align=32 words (r8.0) //.declare (518) rf=r size=128 type=ud align=32 words (r8.0) //.declare (519) rf=r size=8 type=df align=4 words (r5.4) //.declare (520) rf=r size=128 type=ud align=32 words (r10.0) //.declare (521) rf=r size=128 type=ud align=32 words (r10.0) //.declare (522) rf=r size=4 type=d align=2 words (r5.4) //.declare (524) rf=r size=128 type=ud align=32 words (r10.0) //.declare (525) rf=r size=128 type=ud align=32 words (r10.0) //.declare (526) rf=r size=128 type=d align=32 words (r2.0) //.declare (527) rf=r size=128 type=d align=32 words (r2.0) //.declare (528) rf=r size=128 type=ud align=32 words (r2.0) //.declare (529) rf=r size=128 type=ud align=32 words (r2.0) //.declare (530) rf=r size=128 type=w align=32 words (r14.0) //.declare (531) rf=r size=128 type=w align=32 words (r14.0) //.declare (532) rf=r size=128 type=w align=32 words (r14.0) //.declare (533) rf=r size=128 type=w align=32 words (r14.0) //.declare (534) rf=r size=128 type=w align=32 words (r14.0) 
//.declare (535) rf=r size=128 type=w align=32 words (r14.0) //.declare (536) rf=r size=128 type=w align=32 words (r14.0) //.declare (537) rf=r size=128 type=w align=32 words (r14.0) //.declare (538) rf=r size=2 type=w align=1 words (r5.0) //.declare (541) rf=r size=2 type=w align=1 words (r5.0) //.declare (542) rf=r size=2 type=w align=1 words (r5.0) //.declare (543) rf=r size=2 type=w align=1 words (r5.0) //.declare (544) rf=r size=2 type=w align=1 words (r5.0) //.declare (545) rf=r size=2 type=w align=1 words (r5.0) //.declare (546) rf=r size=2 type=w align=1 words (r5.0) //.declare (547) rf=r size=2 type=w align=1 words (r5.0) //.declare (548) rf=r size=128 type=ud align=32 words (r1.0) //.declare (549) rf=r size=128 type=ud align=32 words (r1.0) //.declare (550) rf=r size=128 type=q align=32 words (r114.0) //.declare (551) rf=r size=128 type=q align=32 words (r112.0) //.declare (552) rf=r size=128 type=df align=32 words (r19.0) //.declare (553) rf=r size=128 type=df align=32 words (r17.0) //.declare (554) rf=r size=128 type=df align=32 words (r15.0) //.declare (555) rf=r size=128 type=df align=32 words (r13.0) //.declare (556) rf=r size=128 type=df align=32 words (r17.0) //.declare (557) rf=r size=128 type=df align=32 words (r19.0) //.declare (560) rf=r size=128 type=q align=32 words (r14.0) //.declare (561) rf=r size=128 type=q align=32 words (r10.0) //.declare (562) rf=r size=128 type=df align=32 words (r17.0) //.declare (563) rf=r size=128 type=df align=32 words (r15.0) //.declare (564) rf=r size=128 type=df align=32 words (r13.0) //.declare (565) rf=r size=128 type=df align=32 words (r11.0) //.declare (566) rf=r size=128 type=df align=32 words (r15.0) //.declare (567) rf=r size=128 type=df align=32 words (r17.0) //.declare (570) rf=r size=128 type=q align=32 words (r10.0) //.declare (571) rf=r size=128 type=q align=32 words (r8.0) //.declare (576) rf=r size=128 type=q align=32 words (r19.0) //.declare (577) rf=r size=128 type=q align=32 words (r17.0) 
//.declare (578) rf=r size=128 type=d alias=+0 align=32 words (r114.0) //.declare (579) rf=r size=128 type=d alias=+0 align=32 words (r112.0) //.declare r0 (646) rf=r size=64 type=ud align=32 words (r0.0) //.declare rtmp (647) rf=r size=64 type=ud align=32 words (r127.0) //.declare (648) rf=r size=128 type=ud align=32 words (r1.0) //.declare (649) rf=r size=4 type=ud align=2 words (r126.0) //.declare (650) rf=r size=64 type=ud align=32 words (r3.0) //.declare (651) rf=r size=64 type=ud align=32 words (r4.0) //.declare (652) rf=r size=4 type=ud align=2 words (r126.0) //.declare (653) rf=r size=32 type=ud align=2 words (r5.0) // .inputs // +----------+----------+--------+----------+------------------+ // | id | type | bytes | at | from | // +----------+----------+--------+----------+------------------+ // | V0046 | :w x 32 | 0x40 | r1 | pti[tid]+0x0 | // | V0047 | :w x 32 | 0x40 | r2 | pti[tid]+0x40 | // | V0048 | :w x 32 | 0x40 | r3 | pti[tid]+0x80 | // | V0045 | :d x 8 | 0x20 | r4 | cti+0x0 | // | V0472 | :ud | 0x4 | r4+0x20 | cti+0x20 | // | V0034 | :uq | 0x8 | r4+0x28 | cti+0x28 | // | V0035 | :uq | 0x8 | r4+0x30 | cti+0x30 | // | V0036 | :uq | 0x8 | r4+0x38 | cti+0x38 | // | V0037 | :d | 0x4 | r5 | cti+0x40 | // | V0038 | :d | 0x4 | r5+0x4 | cti+0x44 | // | V0039 | :d | 0x4 | r5+0x8 | cti+0x48 | // | V0040 | :d | 0x4 | r5+0xC | cti+0x4C | // | V0041 | :d | 0x4 | r5+0x10 | cti+0x50 | // | V0042 | :d | 0x4 | r5+0x14 | cti+0x54 | // | V0049 | :uq | 0x8 | r5+0x18 | cti+0x58 | // +----------+----------+--------+----------+------------------+ // B000: Preds:{}, Succs:{B001} per_thread_prolog: (W) mov (16|M0) r127.0<1>:ud 0x0:ud // ALU pipe: int; (W) and (1|M0) r127.2<1>:ud r0.0<0;1,0>:ud 0xFFFFFFC0:ud // ALU pipe: int; (W) and (1|M0) r127.0<1>:uw r0.4<0;1,0>:uw 0xFF:uw // ALU pipe: int; (W) add (1|M0) r127.2<1>:ud r127.2<0;1,0>:ud 0x60:ud {I@2} // ALU pipe: int; (W) add (1|M0) r127.2<1>:ud r127.2<0;1,0>:ud 0x0:ud {I@1} // R_SYM_ADDR_32: 
__INTEL_PATCH_CROSS_THREAD_OFFSET_OFF_R0; ALU pipe: int; (W) mad (1|M0) r127.0<1>:ud r127.2<0;0>:ud r127.0<0;0>:uw 0xC0:uw {I@1} // ALU pipe: int; // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 144: if SPLIT_K == 1: (W) load.ugm.d32x32t.a32.ca.ca (1|M0) r1:2 bti[255][r127:1] {A@1,$0} // ex_desc:0xFF000000; desc:0x6228E500 // (W) add (1|M0) r126.0<1>:ud r127.0<0;1,0>:ud 0x80:uw // ALU pipe: int; // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 144: if SPLIT_K == 1: (W) load.ugm.d32x16t.a32.ca.ca (1|M0) r3:1 bti[255][r126:1] {A@1,$1} // ex_desc:0xFF000000; desc:0x6218D500 // nop // nop // nop // // B001: Preds:{B000}, Succs:{B002} // cross_thread_prolog: (W) and (1|M0) r127.0<1>:ud r0.0<0;1,0>:ud 0xFFFFFFC0:ud {$0.src} // ALU pipe: int; (W) add (1|M0) r127.0<1>:ud r127.0<0;1,0>:ud 0x0:ud {I@1} // R_SYM_ADDR_32: __INTEL_PATCH_CROSS_THREAD_OFFSET_OFF_R0; ALU pipe: int; // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 144: if SPLIT_K == 1: (W) load.ugm.d32x16t.a32.ca.ca (1|M0) r4:1 bti[255][r127:1] {I@1,$2} // ex_desc:0xFF000000; desc:0x6218D500 // (W) add (1|M0) r126.0<1>:ud r127.0<0;1,0>:ud 0x40:uw {$1.src} // ALU pipe: int; (W) load.ugm.d32x8t.a32.ca.ca (1|M0) r5:1 bti[255][r126:1] {I@1,$3} // ex_desc:0xFF000000; desc:0x6218C500 // // B002: Preds:{B001}, Succs:{B003, B004} // _main: (W) or (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x4C0:uw {Compacted,A@1} // $1 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 44: return (x + div - 1) // div (W) add (1|M0) r5.6<1>:d r5.0<0;1,0>:d 15:w {A@1,$3.dst} // ALU pipe: int; $4 (W) cmp (32|M0) (lt)f2.0 null<1>:d r5.6<0;1,0>:d 0:w {I@1} // ALU pipe: int; $5 (W&~f2.0) jmpi _0_034 // ALU pipe: int; $6 // B003: Preds:{B002}, Succs:{B004} _0_035: (W) add (1|M0) r5.6<1>:d r5.0<0;1,0>:d 30:w // ALU pipe: int; $8 // 
B004: Preds:{B003, B002}, Succs:{B005, B006} _0_034: (W) add (1|M0) r5.7<1>:d r5.1<0;1,0>:d 15:w // ALU pipe: int; $11 sync.nop null {Compacted,I@2} // $10 (W) asr (1|M0) r2.9<1>:d r5.6<0;1,0>:d 4:w {$0.dst} // ALU pipe: int; $10 (W) cmp (32|M0) (lt)f1.0 null<1>:d r5.7<0;1,0>:d 0:w {I@2} // ALU pipe: int; $12 (W&~f1.0) jmpi _0_036 // ALU pipe: int; $13 // B005: Preds:{B004}, Succs:{B006} _0_037: (W) add (1|M0) r5.7<1>:d r5.1<0;1,0>:d 30:w // ALU pipe: int; $15 // B006: Preds:{B005, B004}, Succs:{B007, B008} _0_036: // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 104: width = GROUP_M * grid_n (W) asr (1|M0) r5.6<1>:d r5.7<0;1,0>:d 1:w {I@1} // ALU pipe: int; $19 (W) and (1|M0) r2.7<1>:d r5.6<0;1,0>:d -8:w {I@1} // ALU pipe: int; $20 // Line 105: group_id = pid // width (W) cmp (32|M0) (eq)f0.0 null<1>:d r2.7<0;1,0>:d 0:w {I@1} // ALU pipe: int; $22 (W&~f0.0) jmpi _0_038 // ALU pipe: int; $23 // B007: Preds:{B006}, Succs:{B009} _0_039: (W) mov (1|M0) r2.6<1>:d -1:w // ALU pipe: int; $25 (W) jmpi _0_040 // $26 // B008: Preds:{B006}, Succs:{B009} _0_038: (W) asr (1|M0) r2.3<1>:d r5.7<0;1,0>:d 31:w // ALU pipe: int; $28 (W) asr (1|M0) r2.2<1>:d r0.1<0;1,0>:d 31:w // ALU pipe: int; $29 (W) add (1|M0) r5.6<1>:d r2.3<0;1,0>:d r2.7<0;1,0>:d {I@2} // ALU pipe: int; $30 (W) xor (1|M0) r2.0<1>:d r5.6<0;1,0>:d r2.3<0;1,0>:d {Compacted,I@1} // ALU pipe: int; $31 (W) add (1|M0) r5.6<1>:d r2.2<0;1,0>:d r0.1<0;1,0>:d // ALU pipe: int; $32 (W) xor (1|M0) r2.8<1>:d r5.6<0;1,0>:d r2.2<0;1,0>:d {I@1} // ALU pipe: int; $33 (W) mov (1|M0) r5.6<1>:f r2.0<0;1,0>:ud {I@1} // ALU pipe: float; $34 (W) math.inv (1|M0) r2.4<1>:f r5.6<0;1,0>:f {F@1} // ALU pipe: math; $35 (W) mov (1|M0) r5.3<1>:df r2.0<0;1,0>:ud {M@1} // ALU pipe: long; $36 (W) mov (1|M0) r5.4<1>:df r2.4<0;1,0>:f // ALU pipe: long; $38 (W) mov (1|M0) r2.0<1>:df -r5.3<0;1,0>:df {L@2} // ALU pipe: long; $37 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // 
$39 (W) mov (1|M0) r5.3<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $40 (W) mad (1|M0) r2.2<1>:df r5.3<0;0>:df r5.4<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $40 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $41 (W) mov (1|M0) r2.0<1>:df r2.8<0;1,0>:ud {A@1} // ALU pipe: long; $42 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $43 (W) mul (1|M0) r2.0<1>:df r5.4<0;1,0>:df r2.0<0;1,0>:df {Compacted,A@1} // ALU pipe: long; $44 (W) mad (1|M0) r2.0<1>:df r2.0<0;0>:df r2.2<0;0>:df r2.0<0>:df {Compacted,L@1} // ALU pipe: long; $45 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $46 (W) mov (1|M0) r2.0<2>:ud r2.0<0;1,0>:df {A@1} // ALU pipe: int; $47 (W) xor (1|M0) r5.6<1>:d r2.3<0;1,0>:d r2.2<0;1,0>:d // ALU pipe: int; $48 (W) add (1|M0) r5.6<1>:d r5.6<0;1,0>:d r2.0<0;1,0>:d {I@1} // ALU pipe: int; $49 (W) bfn.(s0^s1^s2) (1|M0) r2.6<1>:ud r5.6<0;0>:ud r2.3<0;0>:ud r2.2<0>:ud {I@1} // ALU pipe: int; $50 // B009: Preds:{B008, B007}, Succs:{B010, B011} _0_040: // Line 106: group_size = min(grid_m - group_id * GROUP_M, GROUP_M) (W) shl (1|M0) r2.13<1>:d r2.6<0;1,0>:d 3:w {I@1} // ALU pipe: int; $53 (W) add (1|M0) r5.6<1>:d r2.9<0;1,0>:d -r2.13<0;1,0>:d {I@1} // ALU pipe: int; $54 (W) sel (1|M0) (lt)f0.0 r2.4<1>:d r5.6<0;1,0>:d 8:w {I@1} // ALU pipe: int; $55 // Line 107: pid_m = group_id * GROUP_M + (pid % group_size) (W) cmp (32|M0) (eq)f2.0 null<1>:d r2.4<0;1,0>:d 0:w {I@1} // ALU pipe: int; $57 (W&~f2.0) jmpi _0_041 // ALU pipe: int; $58 // B010: Preds:{B009}, Succs:{B012} _0_042: (W) mov (1|M0) r2.12<1>:d -1:w // ALU pipe: int; $60 (W) jmpi _0_043 // $61 // B011: Preds:{B009}, Succs:{B012} _0_041: (W) asr (1|M0) r2.2<1>:d r0.1<0;1,0>:d 31:w // ALU pipe: int; $63 (W) mov (1|M0) r2.0<1>:d (abs)r2.4<0;1,0>:d // ALU pipe: int; $64 (W) add (1|M0) r5.6<1>:d r2.2<0;1,0>:d r0.1<0;1,0>:d {I@2} // ALU pipe: int; $65 (W) xor (1|M0) r2.3<1>:d r5.6<0;1,0>:d r2.2<0;1,0>:d {I@1} // ALU pipe: int; $66 (W) mov 
(1|M0) r5.6<1>:f r2.0<0;1,0>:ud {I@1} // ALU pipe: float; $67 (W) math.inv (1|M0) r2.1<1>:f r5.6<0;1,0>:f {F@1} // ALU pipe: math; $68 (W) mov (1|M0) r5.3<1>:df r2.0<0;1,0>:ud {M@1} // ALU pipe: long; $69 (W) mov (1|M0) r2.0<1>:df r2.1<0;1,0>:f // ALU pipe: long; $71 (W) mov (1|M0) r2.4<1>:df -r5.3<0;1,0>:df {L@2} // ALU pipe: long; $70 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $72 (W) mov (1|M0) r5.3<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $73 (W) mad (1|M0) r2.5<1>:df r5.3<0;0>:df r2.0<0;0>:df r2.4<0>:df {L@1} // ALU pipe: long; $73 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $74 (W) mov (1|M0) r2.4<1>:df r2.3<0;1,0>:ud {A@1} // ALU pipe: long; $75 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $76 (W) mul (1|M0) r2.0<1>:df r2.0<0;1,0>:df r2.4<0;1,0>:df {Compacted,A@1} // ALU pipe: long; $77 (W) mad (1|M0) r2.0<1>:df r2.0<0;0>:df r2.5<0;0>:df r2.0<0>:df {L@1} // ALU pipe: long; $78 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $79 (W) mov (1|M0) r2.8<2>:ud r2.0<0;1,0>:df {A@1} // ALU pipe: int; $80 (W) mov (1|M0) r2.0<1>:d (abs)r2.4<0;1,0>:d // ALU pipe: int; $81 (W) mul (1|M0) acc0.0<1>:d r2.0<0;1,0>:d r2.16<0;1,0>:uw {Compacted,I@1} // ALU pipe: int; $81 (W) macl (1|M0) r2.0<1>:d r2.0<0;1,0>:d r2.8<0;1,0>:d {Compacted} // ALU pipe: int; $82 (W) add3 (1|M0) r5.6<1>:d r2.3<0;0>:d r2.2<0;0>:d -r2.0<0>:d {I@1} // ALU pipe: int; $82 (W) xor (1|M0) r2.12<1>:d r5.6<0;1,0>:d r2.2<0;1,0>:d {I@1} // ALU pipe: int; $83 // B012: Preds:{B011, B010}, Succs:{B013, B014} _0_043: (W) add (1|M0) r2.9<1>:d r2.13<0;1,0>:d r2.12<0;1,0>:d {I@1} // ALU pipe: int; $85 // Line 108: pid_n = (pid % width) // (group_size) (W&~f2.0) jmpi _0_044 // ALU pipe: int; $87 // B013: Preds:{B012}, Succs:{B015} _0_045: (W) mov (1|M0) r5.8<1>:d -16:w {Compacted} // ALU pipe: int; $89 (W) jmpi _0_046 // $90 // B014: Preds:{B012}, Succs:{B015} _0_044: (W) mul (1|M0) acc0.0<1>:d 
r2.6<0;1,0>:d r2.14<0;1,0>:uw // ALU pipe: int; $92 (W) macl (1|M0) r3.0<1>:d r2.6<0;1,0>:d r2.7<0;1,0>:d {$1.dst} // ALU pipe: int; $93 (W) add (1|M0) r2.5<1>:d r0.1<0;1,0>:d -r3.0<0;1,0>:d {I@1} // ALU pipe: int; $93 (W) asr (2|M0) r2.0<1>:d r2.4<1;1,0>:d 31:w {Compacted,I@1} // ALU pipe: int; $94 (W) add (1|M0) r5.6<1>:d r2.0<0;1,0>:d r2.4<0;1,0>:d {I@1} // ALU pipe: int; $96 (W) xor (1|M0) r2.2<1>:d r5.6<0;1,0>:d r2.0<0;1,0>:d {I@1} // ALU pipe: int; $97 (W) add3 (1|M0) r5.6<1>:d r2.1<0;0>:d r0.1<0;0>:d -r3.0<0>:d // ALU pipe: int; $98 (W) xor (1|M0) r2.8<1>:d r5.6<0;1,0>:d r2.1<0;1,0>:d {I@1} // ALU pipe: int; $99 (W) mov (1|M0) r5.6<1>:f r2.2<0;1,0>:ud {I@1} // ALU pipe: float; $100 (W) math.inv (1|M0) r2.3<1>:f r5.6<0;1,0>:f {F@1} // ALU pipe: math; $101 (W) mov (1|M0) r5.3<1>:df r2.2<0;1,0>:ud {M@1} // ALU pipe: long; $102 (W) mov (1|M0) r2.1<1>:df r2.3<0;1,0>:f // ALU pipe: long; $104 (W) mov (1|M0) r2.2<1>:df -r5.3<0;1,0>:df {L@2} // ALU pipe: long; $103 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $105 (W) mov (1|M0) r5.3<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $106 (W) mad (1|M0) r2.3<1>:df r5.3<0;0>:df r2.1<0;0>:df r2.2<0>:df {L@1} // ALU pipe: long; $106 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $107 (W) mov (1|M0) r2.2<1>:df r2.8<0;1,0>:ud {A@1} // ALU pipe: long; $108 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $109 (W) mul (1|M0) r2.1<1>:df r2.1<0;1,0>:df r2.2<0;1,0>:df {Compacted,A@1} // ALU pipe: long; $110 (W) mad (1|M0) r2.1<1>:df r2.1<0;0>:df r2.3<0;0>:df r2.1<0>:df {L@1} // ALU pipe: long; $111 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $112 (W) mov (1|M0) r2.2<2>:ud r2.1<0;1,0>:df {A@1} // ALU pipe: int; $113 (W) xor (1|M0) r5.6<1>:d r2.0<0;1,0>:d r2.1<0;1,0>:d // ALU pipe: int; $114 (W) add (1|M0) r5.6<1>:d r5.6<0;1,0>:d r2.2<0;1,0>:d {I@1} // ALU pipe: int; $115 (W) bfn.(s0^s1^s2) (1|M0) r5.6<1>:ud r5.6<0;0>:ud 
r2.0<0;0>:ud r2.1<0>:ud {I@1} // ALU pipe: int; $116 (W) shl (1|M0) r5.8<1>:d r5.6<0;1,0>:d 4:w {I@1} // ALU pipe: int; $117 // B015: Preds:{B014, B013}, Succs:{B016, B017} _0_046: // Line 110: rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) shr (32|M0) acc0.0<1>:w r1.0<1;1,0>:uw 1:w // ALU pipe: int; $121 mov (32|M0) r6.0<1>:d r1.0<1;1,0>:uw // ALU pipe: int; $123 and (32|M0) r8.0<1>:w acc0.0<1;1,0>:w 15:w // ALU pipe: int; $122 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 44: return (x + div - 1) // div (W) add (1|M0) r5.7<1>:d r5.2<0;1,0>:d 15:w // ALU pipe: int; $133 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 110: rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) shl (32|M0) r62.0<1>:d r6.0<1;1,0>:d 3:w {Compacted,I@3} // ALU pipe: int; $124 mov (32|M0) r12.0<1>:d r8.0<1;1,0>:uw {I@3} // ALU pipe: int; $126 (W) shl (1|M0) r5.13<1>:d r2.9<0;1,0>:d 4:w // ALU pipe: int; $120 // Line 111: rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) (W) mov (1|M0) r8.0<1>:d 8:w {Compacted} // ALU pipe: int; $129 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 44: return (x + div - 1) // div (W) cmp (32|M0) (lt)f3.0 null<1>:d r5.7<0;1,0>:d 0:w {I@5} // ALU pipe: int; $134 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 110: rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) sync.nop null {Compacted,I@5} // $125 and (32|M0) r2.0<1>:d r62.0<1;1,0>:d 8:w {Compacted,$1.dst} // ALU pipe: int; $125 or (32|M0) r64.0<1>:d r5.13<0;1,0>:d r12.0<1;1,0>:d {I@4} // ALU pipe: int; $127 // Line 111: rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) bfn.(s0&s1|s2) (32|M0) r116.0<1>:ud r62.0<1;0>:ud r8.0<0;0>:ud r5.8<0>:ud {I@4} // ALU pipe: int; $130 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 44: return (x + 
div - 1) // div (W&f3.0) jmpi _0_047 // ALU pipe: int; $135 // B016: Preds:{B015}, Succs:{B018} _0_048: (W) mov (1|M0) r5.6<1>:d r5.7<0;1,0>:d // ALU pipe: int; $137 (W) jmpi _0_049 // $138 // B017: Preds:{B015}, Succs:{B018} _0_047: (W) add (1|M0) r5.6<1>:d r5.2<0;1,0>:d 30:w // ALU pipe: int; $140 // B018: Preds:{B017, B016}, Succs:{B019, B020} _0_049: mov (16|M0) r8.0<4>:uw r1.0<1;1,0>:uw // ALU pipe: int; $143 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 119: for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): (W) cmp (32|M0) (gt)f2.0 null<1>:d r5.7<0;1,0>:d 15:w // ALU pipe: int; $146 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 44: return (x + div - 1) // div mov (16|M0) r114.0<1>:q r8.0<4;1,0>:uw {I@2} // ALU pipe: int; $143 mov (16|M16) r8.0<4>:uw r1.16<1;1,0>:uw // ALU pipe: int; $143 mov (16|M16) r112.0<1>:q r8.0<4;1,0>:uw {I@1} // ALU pipe: int; $143 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 119: for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): (W&f2.0) jmpi _0_050 // ALU pipe: int; $147 // B019: Preds:{B018}, Succs:{B030} _0_051: mov (32|M0) r54.0<1>:uw 0x0:uw // ALU pipe: int; $149 mov (32|M0) r55.0<1>:uw 0x0:uw // ALU pipe: int; $150 mov (32|M0) r56.0<1>:uw 0x0:uw // ALU pipe: int; $151 mov (32|M0) r57.0<1>:uw 0x0:uw // ALU pipe: int; $152 mov (32|M0) r58.0<1>:hf 0x0:hf // ALU pipe: float; $153 mov (32|M0) r59.0<1>:hf 0x0:hf // ALU pipe: float; $154 mov (32|M0) r60.0<1>:hf 0x0:hf // ALU pipe: float; $155 mov (32|M0) r61.0<1>:hf 0x0:hf // ALU pipe: float; $156 (W) jmpi _0_052 // $157 // B020: Preds:{B018}, Succs:{B021, B022} _0_050: // Line 114: rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) (W) shl (1|M0) r5.7<1>:d r0.6<0;1,0>:d 4:w // ALU pipe: int; $162 // Line 113: rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) (W) cmp (32|M0) (eq)f1.0 null<1>:d 
r5.1<0;1,0>:d 0:w // ALU pipe: int; $167 // Line 114: rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) or (32|M0) r10.0<1>:d r5.7<0;1,0>:d r12.0<1;1,0>:d {I@2} // ALU pipe: int; $163 // Line 117: B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) (W) mul (16|M0) acc0.0<1>:d r10.0<1;1,0>:d r5.8<0;1,0>:uw {Compacted,I@1} // ALU pipe: int; $165 macl (16|M0) r8.0<1>:d r10.0<1;1,0>:d r5.4<0;1,0>:d {Compacted} // ALU pipe: int; $165 (W) mul (16|M16) acc0.0<1>:d r11.0<1;1,0>:d r5.8<0;1,0>:uw {Compacted} // ALU pipe: int; $165 // Line 113: rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) macl (16|M16) r9.0<1>:d r11.0<1;1,0>:d r5.4<0;1,0>:d {Compacted} // ALU pipe: int; $167 // Line 136: B += BLOCK_K * SPLIT_K * stride_bk (W) shl (1|M0) r5.2<1>:d r5.4<0;1,0>:d 4:w // ALU pipe: int; $160 // Line 113: rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) (W&~f1.0) jmpi _0_053 // ALU pipe: int; $168 // B021: Preds:{B020}, Succs:{B023} _0_054: mov (32|M0) r10.0<1>:d -1:w {Compacted} // ALU pipe: int; $170 (W) jmpi _0_055 // $171 // B022: Preds:{B020}, Succs:{B023} _0_053: (W) asr (1|M0) r5.12<1>:d r5.8<0;1,0>:d 31:w // ALU pipe: int; $173 (W) mov (1|M0) r5.8<1>:d (abs)r5.1<0;1,0>:d // ALU pipe: int; $174 add (32|M0) r12.0<1>:d r5.12<0;1,0>:d r116.0<1;1,0>:d {Compacted,I@2} // ALU pipe: int; $175 (W) mov (1|M0) r5.4<1>:f r5.8<0;1,0>:ud {I@2} // ALU pipe: float; $177 (W) mov (1|M0) r5.4<1>:df r5.8<0;1,0>:ud {F@1} // ALU pipe: long; $179 (W) math.inv (1|M0) r5.4<1>:f r5.4<0;1,0>:f // ALU pipe: math; $178 xor (32|M0) r10.0<1>:d r12.0<1;1,0>:d r5.12<0;1,0>:d {Compacted,I@1} // ALU pipe: int; $176 (W) mov (1|M0) r5.5<1>:df r5.4<0;1,0>:f {M@1} // ALU pipe: long; $181 (W) mov (1|M0) r12.0<1>:df -r5.4<0;1,0>:df {A@1} // ALU pipe: long; $180 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $182 (W) mov (1|M0) r5.4<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $183 (W) mad (1|M0) r22.0<1>:df r5.4<0;0>:df r5.5<0;0>:df 
r12.0<0>:df {L@1} // ALU pipe: long; $183 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $184 mov (16|M0) r12.0<2>:ud r10.0<1;1,0>:ud {Compacted,A@1} // ALU pipe: int; $185 mov (16|M0) r19.0<1>:df r12.0<2;1,0>:ud {I@1} // ALU pipe: long; $185 mov (16|M16) r12.0<2>:ud r11.0<1;1,0>:ud {Compacted,L@1} // ALU pipe: int; $185 mov (16|M16) r17.0<1>:df r12.0<2;1,0>:ud {I@1} // ALU pipe: long; $185 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $186 mul (16|M0) acc0.0<1>:df r5.5<0;1,0>:df r19.0<1;1,0>:df {A@1} // ALU pipe: long; $187 mul (16|M16) acc2.0<1>:df r5.5<0;1,0>:df r17.0<1;1,0>:df {L@2} // ALU pipe: long; $187 mad (16|M0) r17.0<1>:df acc0.0<1;0>:df acc0.0<1;0>:df r22.0<0>:df // ALU pipe: long; $188 mad (16|M16) r19.0<1>:df acc2.0<1;0>:df acc2.0<1;0>:df r22.0<0>:df // ALU pipe: long; $188 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $189 mov (16|M0) r12.0<2>:ud r17.0<1;1,0>:df {A@1} // ALU pipe: int; $190 mov (16|M0) r14.0<1>:ud r12.0<2;1,0>:ud {Compacted,I@1} // ALU pipe: int; $190 mov (16|M16) r12.0<2>:ud r19.0<1;1,0>:df {L@1} // ALU pipe: int; $190 (W) mov (1|M0) r5.4<1>:d (abs)r5.1<0;1,0>:d // ALU pipe: int; $191 mov (16|M16) r15.0<1>:ud r12.0<2;1,0>:ud {Compacted,I@2} // ALU pipe: int; $190 (W) mul (16|M0) acc0.0<1>:d r5.4<0;1,0>:d r14.0<2;1,0>:uw {Compacted,I@2} // ALU pipe: int; $191 macl (16|M0) r16.0<1>:d r5.4<0;1,0>:d r14.0<1;1,0>:d {Compacted} // ALU pipe: int; $191 (W) mul (16|M16) acc0.0<1>:d r5.4<0;1,0>:d r15.0<2;1,0>:uw {Compacted,I@3} // ALU pipe: int; $191 macl (16|M16) r17.0<1>:d r5.4<0;1,0>:d r15.0<1;1,0>:d {Compacted} // ALU pipe: int; $192 add3 (32|M0) r12.0<1>:d r10.0<1;0>:d r5.12<0;0>:d -r16.0<1>:d {I@1} // ALU pipe: int; $192 xor (32|M0) r10.0<1>:d r12.0<1;1,0>:d r5.12<0;1,0>:d {Compacted,I@1} // ALU pipe: int; $193 // B023: Preds:{B022, B021}, Succs:{B024, B025} _0_055: // Line 117: B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) add (32|M0) r12.0<1>:d 
r8.0<1;1,0>:d r10.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $196 mov (16|M0) r8.0<2>:ud r12.0<1;1,0>:ud {Compacted,I@1} // ALU pipe: int; $198 shl (16|M0) r14.0<1>:q r8.0<2;1,0>:d 1:w {Compacted,I@1} // ALU pipe: int; $198 mov (16|M16) r8.0<2>:ud r13.0<1;1,0>:ud {Compacted} // ALU pipe: int; $198 // Line 112: ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) (W) cmp (32|M0) (eq)f0.0 null<1>:d r5.0<0;1,0>:d 0:w // ALU pipe: int; $201 // Line 117: B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) shl (16|M16) r10.0<1>:q r8.0<2;1,0>:d 1:w {Compacted,I@2} // ALU pipe: int; $198 add (16|M0) r22.0<1>:q r14.0<1;1,0>:q r4.6<0;1,0>:q {Compacted,$2.dst} // ALU pipe: int; $199 add (16|M16) r24.0<1>:q r10.0<1;1,0>:q r4.6<0;1,0>:q {Compacted,I@2} // ALU pipe: int; $199 // Line 112: ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) (W&~f0.0) jmpi _0_056 // ALU pipe: int; $202 // B024: Preds:{B023}, Succs:{B026} _0_057: mov (32|M0) r8.0<1>:d -1:w {Compacted} // ALU pipe: int; $204 (W) jmpi _0_058 // $205 // B025: Preds:{B023}, Succs:{B026} _0_056: (W) mov (1|M0) r5.8<1>:d (abs)r5.0<0;1,0>:d // ALU pipe: int; $208 (W) asr (1|M0) r5.12<1>:d r5.13<0;1,0>:d 31:w // ALU pipe: int; $207 (W) mov (1|M0) r5.4<1>:f r5.8<0;1,0>:ud {I@2} // ALU pipe: float; $211 add (32|M0) r10.0<1>:d r5.12<0;1,0>:d r64.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $209 (W) math.inv (1|M0) r5.4<1>:f r5.4<0;1,0>:f {F@1} // ALU pipe: math; $212 (W) mov (1|M0) r5.4<1>:df r5.8<0;1,0>:ud // ALU pipe: long; $213 xor (32|M0) r8.0<1>:d r10.0<1;1,0>:d r5.12<0;1,0>:d {Compacted,I@1} // ALU pipe: int; $210 (W) mov (1|M0) r5.5<1>:df r5.4<0;1,0>:f {M@1} // ALU pipe: long; $215 (W) mov (1|M0) r10.0<1>:df -r5.4<0;1,0>:df {A@1} // ALU pipe: long; $214 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $216 (W) mov (1|M0) r5.4<1>:df 0x3FF0000000004000:df {A@1} // ALU pipe: long; $217 (W) mad (1|M0) r20.0<1>:df r5.4<0;0>:df r5.5<0;0>:df r10.0<0>:df {L@1} // ALU 
pipe: long; $217 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $218 mov (16|M0) r10.0<2>:ud r8.0<1;1,0>:ud {Compacted,A@1} // ALU pipe: int; $219 mov (16|M0) r17.0<1>:df r10.0<2;1,0>:ud {I@1} // ALU pipe: long; $219 mov (16|M16) r10.0<2>:ud r9.0<1;1,0>:ud {Compacted,L@1} // ALU pipe: int; $219 mov (16|M16) r15.0<1>:df r10.0<2;1,0>:ud {I@1} // ALU pipe: long; $219 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $220 mul (16|M0) acc0.0<1>:df r5.5<0;1,0>:df r17.0<1;1,0>:df {A@1} // ALU pipe: long; $221 mul (16|M16) acc2.0<1>:df r5.5<0;1,0>:df r15.0<1;1,0>:df {L@2} // ALU pipe: long; $221 mad (16|M0) r15.0<1>:df acc0.0<1;0>:df acc0.0<1;0>:df r20.0<0>:df // ALU pipe: long; $222 mad (16|M16) r17.0<1>:df acc2.0<1;0>:df acc2.0<1;0>:df r20.0<0>:df // ALU pipe: long; $222 (W) xor (1|M0) cr0.0<1>:ud cr0.0<0;1,0>:ud 0x30:uw {Compacted,A@1} // $223 mov (16|M0) r10.0<2>:ud r15.0<1;1,0>:df {A@1} // ALU pipe: int; $224 mov (16|M0) r12.0<1>:ud r10.0<2;1,0>:ud {Compacted,I@1} // ALU pipe: int; $224 mov (16|M16) r10.0<2>:ud r17.0<1;1,0>:df {L@1} // ALU pipe: int; $224 (W) mov (1|M0) r5.4<1>:d (abs)r5.0<0;1,0>:d // ALU pipe: int; $225 mov (16|M16) r13.0<1>:ud r10.0<2;1,0>:ud {Compacted,I@2} // ALU pipe: int; $224 (W) mul (16|M0) acc0.0<1>:d r5.4<0;1,0>:d r12.0<2;1,0>:uw {Compacted,I@2} // ALU pipe: int; $225 macl (16|M0) r14.0<1>:d r5.4<0;1,0>:d r12.0<1;1,0>:d {Compacted} // ALU pipe: int; $225 (W) mul (16|M16) acc0.0<1>:d r5.4<0;1,0>:d r13.0<2;1,0>:uw {Compacted,I@3} // ALU pipe: int; $225 macl (16|M16) r15.0<1>:d r5.4<0;1,0>:d r13.0<1;1,0>:d {Compacted} // ALU pipe: int; $226 add3 (32|M0) r10.0<1>:d r8.0<1;0>:d r5.12<0;0>:d -r14.0<1>:d {I@1} // ALU pipe: int; $226 xor (32|M0) r8.0<1>:d r10.0<1;1,0>:d r5.12<0;1,0>:d {Compacted,I@1} // ALU pipe: int; $227 // B026: Preds:{B025, B024}, Succs:{B027} _0_058: // Line 116: A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) (W) mul (16|M0) acc0.0<1>:d r8.0<1;1,0>:d r5.6<0;1,0>:uw 
{Compacted,I@1} // ALU pipe: int; $234 macl (16|M0) r10.0<1>:d r8.0<1;1,0>:d r5.3<0;1,0>:d {Compacted} // ALU pipe: int; $234 (W) mul (16|M16) acc0.0<1>:d r9.0<1;1,0>:d r5.6<0;1,0>:uw {Compacted} // ALU pipe: int; $234 macl (16|M16) r11.0<1>:d r9.0<1;1,0>:d r5.3<0;1,0>:d {Compacted} // ALU pipe: int; $235 add3 (32|M0) r8.0<1>:d r10.0<1;0>:d r5.7<0;0>:d r2.0<1>:d {I@1} // ALU pipe: int; $235 mov (16|M0) r2.0<2>:ud r8.0<1;1,0>:ud {Compacted,I@1} // ALU pipe: int; $237 shl (16|M0) r10.0<1>:q r2.0<2;1,0>:d 1:w {Compacted,I@1} // ALU pipe: int; $237 mov (16|M16) r2.0<2>:ud r9.0<1;1,0>:ud {Compacted} // ALU pipe: int; $237 shl (16|M16) r8.0<1>:q r2.0<2;1,0>:d 1:w {Compacted,I@1} // ALU pipe: int; $237 shl (16|M0) r2.0<2>:d r114.0<1;1,0>:q 3:w // ALU pipe: int; $239 mov (16|M0) r62.0<1>:d r2.0<2;1,0>:d {Compacted,I@1} // ALU pipe: int; $239 shl (16|M16) r2.0<2>:d r112.0<1;1,0>:q 3:w // ALU pipe: int; $239 mov (16|M16) r63.0<1>:d r2.0<2;1,0>:d {Compacted,I@1} // ALU pipe: int; $239 and (32|M0) r1.0<1>:w r1.0<1;1,0>:w 31:w // ALU pipe: int; $246 add (16|M16) r36.0<1>:q r8.0<1;1,0>:q r4.5<0;1,0>:q {Compacted} // ALU pipe: int; $238 shl (32|M0) r8.0<1>:d r62.0<1;1,0>:d 1:w {Compacted,I@3} // ALU pipe: int; $241 add (16|M0) r34.0<1>:q r10.0<1;1,0>:q r4.5<0;1,0>:q {Compacted} // ALU pipe: int; $238 shl (32|M0) r10.0<1>:w r1.0<1;1,0>:w 1:w {I@4} // ALU pipe: int; $247 and (32|M0) r2.0<1>:d r8.0<1;1,0>:d 496:w {Compacted,I@3} // ALU pipe: int; $242 shl (32|M0) r8.0<1>:d r1.0<1;1,0>:uw 1:w // ALU pipe: int; $249 mov (32|M0) r98.0<1>:ud r10.0<1;1,0>:uw {I@3} // ALU pipe: int; $253 or (32|M0) r1.0<1>:w r10.0<1;1,0>:w 256:w // ALU pipe: int; $254 shl (32|M0) r10.0<1>:d r6.0<1;1,0>:d 1:w {Compacted} // ALU pipe: int; $260 and (32|M0) r12.0<1>:d r6.0<1;1,0>:d 15:w {Compacted} // ALU pipe: int; $259 (W) mov (1|M0) r5.3<1>:d 32:w {Compacted} // ALU pipe: int; $261 or (32|M0) r104.0<1>:d r8.0<1;1,0>:d 64:w {Compacted,I@6} // ALU pipe: int; $250 or (32|M0) r102.0<1>:d r8.0<1;1,0>:d 128:w 
{Compacted} // ALU pipe: int; $251 or (32|M0) r100.0<1>:d r8.0<1;1,0>:d 192:w {Compacted} // ALU pipe: int; $252 or (32|M0) r94.0<1>:d r8.0<1;1,0>:d 320:w {Compacted} // ALU pipe: int; $256 or (32|M0) r92.0<1>:d r8.0<1;1,0>:d 384:w {Compacted} // ALU pipe: int; $257 or (32|M0) r90.0<1>:d r8.0<1;1,0>:d 448:w {Compacted} // ALU pipe: int; $258 bfn.(s0&s1|s2) (32|M0) r8.0<1>:ud r10.0<1;0>:ud r5.3<0;0>:ud r12.0<1>:ud {I@7} // ALU pipe: int; $262 shl (32|M0) r6.0<1>:d r8.0<1;1,0>:d 1:w {Compacted,I@1} // ALU pipe: int; $263 // Line 119: for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): mov (32|M0) r52.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $279 mov (32|M0) r50.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $280 mov (32|M0) r48.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $281 mov (32|M0) r46.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $282 mov (32|M0) r44.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $283 mov (32|M0) r42.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $284 mov (32|M0) r40.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $285 mov (32|M0) r38.0<1>:f 0x0:f {Compacted} // ALU pipe: float; $286 // Line 116: A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) or (32|M0) r110.0<1>:d r2.0<1;1,0>:d 8:w {Compacted} // ALU pipe: int; $243 or (32|M0) r108.0<1>:d r2.0<1;1,0>:d 1024:w {Compacted} // ALU pipe: int; $244 or (32|M0) r106.0<1>:d r2.0<1;1,0>:d 1032:w {Compacted} // ALU pipe: int; $245 or (32|M0) r72.0<1>:d r2.0<1;1,0>:d 4:w {Compacted} // ALU pipe: int; $272 or (32|M0) r70.0<1>:d r2.0<1;1,0>:d 12:w {Compacted} // ALU pipe: int; $273 or (32|M0) r68.0<1>:d r2.0<1;1,0>:d 1028:w {Compacted} // ALU pipe: int; $274 or (32|M0) r66.0<1>:d r2.0<1;1,0>:d 1036:w {Compacted} // ALU pipe: int; $275 mov (32|M0) r96.0<1>:ud r1.0<1;1,0>:uw // ALU pipe: int; $255 or (32|M0) r88.0<1>:d r6.0<1;1,0>:d 1024:w {Compacted,I@7} // ALU pipe: int; $264 or (32|M0) r86.0<1>:d r6.0<1;1,0>:d 1056:w {Compacted} // ALU pipe: int; $265 or (32|M0) r84.0<1>:d r6.0<1;1,0>:d 1152:w 
{Compacted} // ALU pipe: int; $266 or (32|M0) r82.0<1>:d r6.0<1;1,0>:d 1184:w {Compacted} // ALU pipe: int; $267 or (32|M0) r80.0<1>:d r6.0<1;1,0>:d 1280:w {Compacted} // ALU pipe: int; $268 or (32|M0) r78.0<1>:d r6.0<1;1,0>:d 1312:w {Compacted} // ALU pipe: int; $269 or (32|M0) r76.0<1>:d r6.0<1;1,0>:d 1408:w {Compacted} // ALU pipe: int; $270 or (32|M0) r74.0<1>:d r6.0<1;1,0>:d 1440:w {Compacted} // ALU pipe: int; $271 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/language/standard.py // Line 44: return (x + div - 1) // div (W) asr (1|M0) r5.6<1>:d r5.6<0;1,0>:d 4:w // ALU pipe: int; $231 // File: /home/jovyan/workspace/triton/intel-xpu-backend-for-triton/python/triton/ops/matmul.py // Line 119: for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): (W) mov (1|M0) r5.4<1>:d 0:w {Compacted} // ALU pipe: int; $287 // Line 116: A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) (W) shl (1|M0) r5.1<1>:q r5.2<0;1,0>:d 1:w {Compacted} // ALU pipe: int; $277 // B027: Preds:{B028, B026}, Succs:{B028, B029} _0_059: // Line 121: a = tl.load(A) load.ugm.d32x4.a64 (32|M0) r6:8 [r34:4] {I@4,$4} // ex_desc:0x0; desc:0x8803580 // $290 shr (32|M0) r14.0<2>:w r6.0<1;1,0>:ud 16:w {$4.dst} // ALU pipe: int; $292 mov (32|M0) r26.0<2>:w r6.0<1;1,0>:d // ALU pipe: int; $291 mov (32|M0) r26.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $292 shr (32|M0) r14.0<2>:w r8.0<1;1,0>:ud 16:w // ALU pipe: int; $295 mov (32|M0) r20.0<2>:w r8.0<1;1,0>:d // ALU pipe: int; $294 mov (32|M0) r20.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $295 shr (32|M0) r14.0<2>:w r10.0<1;1,0>:ud 16:w // ALU pipe: int; $298 mov (32|M0) r18.0<2>:w r10.0<1;1,0>:d // ALU pipe: int; $297 mov (32|M0) r18.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $298 shr (32|M0) r14.0<2>:w r12.0<1;1,0>:ud 16:w // ALU pipe: int; $301 mov (32|M0) r16.0<2>:w r12.0<1;1,0>:d // ALU pipe: int; $300 mov (32|M0) r16.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $301 sync.nop null {Compacted,$5.src} 
// $303 (W) send.slm (1|M0) r1 r0 null:0 0x0 0x0210001F {$6} // wr:1+0, rd:1; fence.slm.none.group // $303 (W) mov (8|M0) null<1>:ud r1.0<1;1,0>:ud {Compacted,$6.dst} // memory fence commit; ALU pipe: int; $304 (W) mov (1|M0) r1.2<1>:f 0x0:f {Compacted,I@1} // signal barrier payload init; (0x00000000:f); ALU pipe: float; $304 (W) mov (2|M0) r1.10<1>:ub r0.11<0;1,0>:ub {F@1} // signal barrier payload (nprods, ncons); ALU pipe: int; $304 (W) send.gtwy (1|M0) null r1 null:0 0x0 0x02000004 {I@1,$7} // wr:1+0, rd:0; signal barrier // $304 (W) sync.bar 0x0 {Compacted} // $304 // Line 122: b = tl.load(B) load.ugm.d32x4.a64 (32|M0) r6:8 [r22:4] {$8} // ex_desc:0x0; desc:0x8803580 // $326 // Line 121: a = tl.load(A) store.slm.d32.a32 (32|M0) [r2:2] r26:2 {$9} // ex_desc:0x0; desc:0x4000504 // $311 store.slm.d32.a32 (32|M0) [r72:2] r20:2 {$10} // ex_desc:0x0; desc:0x4000504 // $314 store.slm.d32.a32 (32|M0) [r110:2] r18:2 {$11} // ex_desc:0x0; desc:0x4000504 // $321 store.slm.d32.a32 (32|M0) [r70:2] r16:2 {$12} // ex_desc:0x0; desc:0x4000504 // $324 // Line 122: b = tl.load(B) shr (32|M0) r14.0<2>:w r6.0<1;1,0>:ud 16:w {$8.dst} // ALU pipe: int; $328 mov (32|M0) r26.0<2>:w r6.0<1;1,0>:d {$9.src} // ALU pipe: int; $327 mov (32|M0) r26.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $328 shr (32|M0) r14.0<2>:w r8.0<1;1,0>:ud 16:w // ALU pipe: int; $331 mov (32|M0) r20.0<2>:w r8.0<1;1,0>:d {$10.src} // ALU pipe: int; $330 mov (32|M0) r20.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $331 shr (32|M0) r14.0<2>:w r10.0<1;1,0>:ud 16:w // ALU pipe: int; $334 mov (32|M0) r18.0<2>:w r10.0<1;1,0>:d {$11.src} // ALU pipe: int; $333 mov (32|M0) r18.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $334 shr (32|M0) r14.0<2>:w r12.0<1;1,0>:ud 16:w // ALU pipe: int; $337 store.slm.d32.a32 (32|M0) [r108:2] r26:2 {$13} // ex_desc:0x0; desc:0x4000504 // $345 mov (32|M0) r16.0<2>:w r12.0<1;1,0>:d {$12.src} // ALU pipe: int; $336 store.slm.d32.a32 (32|M0) [r68:2] r20:2 {$14} // ex_desc:0x0; 
desc:0x4000504 // $348 mov (32|M0) r16.1<2>:w r14.0<2;1,0>:w {I@2} // ALU pipe: int; $337 store.slm.d32.a32 (32|M0) [r106:2] r18:2 {$15} // ex_desc:0x0; desc:0x4000504 // $355 store.slm.d32.a32 (32|M0) [r66:2] r16:2 {A@1,$0} // ex_desc:0x0; desc:0x4000504 // $358 sync.nop null {Compacted,$7.src} // $359 (W) send.slm (1|M0) r1 r0 null:0 0x0 0x0210001F {$1} // wr:1+0, rd:1; fence.slm.none.group // $359 (W) mov (8|M0) null<1>:ud r1.0<1;1,0>:ud {Compacted,$1.dst} // memory fence commit; ALU pipe: int; $360 (W) mov (1|M0) r1.2<1>:f 0x0:f {Compacted,I@1} // signal barrier payload init; (0x00000000:f); ALU pipe: float; $360 (W) mov (2|M0) r1.10<1>:ub r0.11<0;1,0>:ub {F@1} // signal barrier payload (nprods, ncons); ALU pipe: int; $360 (W) send.gtwy (1|M0) null r1 null:0 0x0 0x02000004 {I@1,$5} // wr:1+0, rd:0; signal barrier // $360 (W) sync.bar 0x0 {Compacted} // $360 // Line 121: a = tl.load(A) load.slm.d16u32.a32 (32|M0) r6:2 [r98:2] {$3} // ex_desc:0x0; desc:0x4200B00 // $362 // Line 122: b = tl.load(B) load.slm.d16u32.a32 (32|M0) r14:2 [r88:2] {$4} // ex_desc:0x0; desc:0x4200B00 // $379 // Line 132: acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) mov (32|M0) r26.0<1>:f r52.0<1;1,0>:f {Compacted,$13.src} // ALU pipe: float; $420 // Line 121: a = tl.load(A) mov (32|M0) r58.0<1>:w r6.0<2;1,0>:w {$3.dst} // ALU pipe: int; $363 load.slm.d16u32.a32 (32|M0) r6:2 [r104:2] {I@1,$6} // ex_desc:0x0; desc:0x4200B00 // $364 // Line 132: acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) mov (32|M0) r28.0<1>:f r50.0<1;1,0>:f {Compacted} // ALU pipe: float; $421 mov (32|M0) r30.0<1>:f r48.0<1;1,0>:f {Compacted} // ALU pipe: float; $422 // Line 121: a = tl.load(A) mov (32|M0) r59.0<1>:w r6.0<2;1,0>:w {$6.dst} // ALU pipe: int; $365 load.slm.d16u32.a32 (32|M0) r6:2 [r102:2] {I@1,$8} // ex_desc:0x0; desc:0x4200B00 // $366 // Line 132: acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) mov (32|M0) r32.0<1>:f r46.0<1;1,0>:f 
{Compacted} // ALU pipe: float; $423 mov (32|M0) r16.0<1>:f r42.0<1;1,0>:f {Compacted,$0.src} // ALU pipe: float; $425 // Line 121: a = tl.load(A) mov (32|M0) r60.0<1>:w r6.0<2;1,0>:w {$8.dst} // ALU pipe: int; $367 load.slm.d16u32.a32 (32|M0) r6:2 [r100:2] {I@1,$9} // ex_desc:0x0; desc:0x4200B00 // $368 // Line 132: acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) mov (32|M0) r18.0<1>:f r40.0<1;1,0>:f {Compacted,$15.src} // ALU pipe: float; $426 mov (32|M0) r20.0<1>:f r38.0<1;1,0>:f {Compacted,$14.src} // ALU pipe: float; $427 // Line 121: a = tl.load(A) mov (32|M0) r61.0<1>:w r6.0<2;1,0>:w {$9.dst} // ALU pipe: int; $369 load.slm.d16u32.a32 (32|M0) r6:2 [r96:2] {I@1,$10} // ex_desc:0x0; desc:0x4200B00 // $370 // Line 119: for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): (W) add (1|M0) r5.4<1>:d r5.4<0;1,0>:d 1:w {Compacted} // ALU pipe: int; $439 (W) cmp (32|M0) (eq)f3.0 null<1>:d r5.4<0;1,0>:d r5.6<0;1,0>:d {I@1} // ALU pipe: int; $440 // Line 121: a = tl.load(A) mov (32|M0) r54.0<1>:w r6.0<2;1,0>:w {$10.dst} // ALU pipe: int; $371 load.slm.d16u32.a32 (32|M0) r6:2 [r94:2] {I@1,$11} // ex_desc:0x0; desc:0x4200B00 // $372 mov (32|M0) r55.0<1>:w r6.0<2;1,0>:w {$11.dst} // ALU pipe: int; $373 load.slm.d16u32.a32 (32|M0) r6:2 [r92:2] {I@1,$12} // ex_desc:0x0; desc:0x4200B00 // $374 mov (32|M0) r56.0<1>:w r6.0<2;1,0>:w {$12.dst} // ALU pipe: int; $375 load.slm.d16u32.a32 (32|M0) r6:2 [r90:2] {I@1,$13} // ex_desc:0x0; desc:0x4200B00 // $376 mov (32|M0) r57.0<1>:w r6.0<2;1,0>:w {$13.dst} // ALU pipe: int; $377 // Line 122: b = tl.load(B) sync.nop null {Compacted,I@1} // $380 mov (32|M0) r6.0<2>:hf r14.0<2;1,0>:hf {$4.dst} // ALU pipe: float; $380 load.slm.d16u32.a32 (32|M0) r14:2 [r86:2] {F@1,$14} // ex_desc:0x0; desc:0x4200B00 // $381 mov (32|M0) r6.1<2>:uw r14.0<2;1,0>:uw {$14.dst} // ALU pipe: int; $382 load.slm.d16u32.a32 (32|M0) r14:2 [r84:2] {I@1,$15} // ex_desc:0x0; desc:0x4200B00 // $383 mov (32|M0) r8.0<2>:hf r14.0<2;1,0>:hf {$15.dst} // ALU 
pipe: float; $384 load.slm.d16u32.a32 (32|M0) r14:2 [r82:2] {F@1,$0} // ex_desc:0x0; desc:0x4200B00 // $385 mov (32|M0) r8.1<2>:uw r14.0<2;1,0>:uw {$0.dst} // ALU pipe: int; $386 load.slm.d16u32.a32 (32|M0) r14:2 [r80:2] {I@1,$1} // ex_desc:0x0; desc:0x4200B00 // $387 mov (32|M0) r10.0<2>:hf r14.0<2;1,0>:hf {$1.dst} // ALU pipe: float; $388 load.slm.d16u32.a32 (32|M0) r14:2 [r78:2] {F@1,$3} // ex_desc:0x0; desc:0x4200B00 // $389 mov (32|M0) r10.1<2>:uw r14.0<2;1,0>:uw {$3.dst} // ALU pipe: int; $390 load.slm.d16u32.a32 (32|M0) r14:2 [r76:2] {I@1,$4} // ex_desc:0x0; desc:0x4200B00 // $391 mov (32|M0) r12.0<2>:hf r14.0<2;1,0>:hf {$4.dst} // ALU pipe: float; $392 load.slm.d16u32.a32 (32|M0) r14:2 [r74:2] {F@1,$6} // ex_desc:0x0; desc:0x4200B00 // $393 mov (32|M0) r12.1<2>:uw r14.0<2;1,0>:uw {$6.dst} // ALU pipe: int; $394 // Line 132: acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) mov (32|M0) r14.0<1>:f r44.0<1;1,0>:f {Compacted,I@1} // ALU pipe: float; $424 dpas.8x8 (16|M0) r26:f r26:f r6:hf r58.0:hf {Atomic,Compacted,F@1} // $428 R{} IR{}{E:13,E:3,E:13,}, R{} IR{}{O:13,O:3,O:13,}, {BC=2} dpas.8x8 (16|M0) r14:f r14:f r6:hf r54.0:hf {Compacted,$7} // $433 R{} IR{}{E:7,E:3,E:11,}, R{} IR{}{O:7,O:3,O:11,}, {BC=2} mov (32|M0) r52.0<1>:f r26.0<1;1,0>:f {Compacted,$7.dst} // ALU pipe: float; $429 mov (32|M0) r50.0<1>:f r28.0<1;1,0>:f {Compacted} // ALU pipe: float; $430 mov (32|M0) r48.0<1>:f r30.0<1;1,0>:f {Compacted} // ALU pipe: float; $431 mov (32|M0) r46.0<1>:f r32.0<1;1,0>:f {Compacted} // ALU pipe: float; $432 mov (32|M0) r44.0<1>:f r14.0<1;1,0>:f {Compacted} // ALU pipe: float; $434 mov (32|M0) r42.0<1>:f r16.0<1;1,0>:f {Compacted} // ALU pipe: float; $435 mov (32|M0) r40.0<1>:f r18.0<1;1,0>:f {Compacted} // ALU pipe: float; $436 mov (32|M0) r38.0<1>:f r20.0<1;1,0>:f {Compacted} // ALU pipe: float; $437 // Line 119: for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): (W&f3.0) jmpi _0_060 // ALU pipe: int; $441 // B028: Preds:{B027}, Succs:{B027} 
_0_061: // Line 135: A += BLOCK_K * SPLIT_K * stride_ak add (16|M0) r34.0<1>:q r34.0<1;1,0>:q 32:w {Compacted} // ALU pipe: int; $444 add (16|M16) r36.0<1>:q r36.0<1;1,0>:q 32:w {Compacted} // ALU pipe: int; $444 // Line 136: B += BLOCK_K * SPLIT_K * stride_bk add (16|M0) r22.0<1>:q r22.0<1;1,0>:q r5.1<0;1,0>:q {Compacted} // ALU pipe: int; $446 add (16|M16) r24.0<1>:q r24.0<1;1,0>:q r5.1<0;1,0>:q {Compacted} // ALU pipe: int; $446 // Line 119: for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): (W) jmpi _0_059 // $448 // B029: Preds:{B027}, Succs:{B030} _0_060: // Line 137: acc = acc.to(C.dtype.element_ty) mov (16|M0) r54.0<1>:hf r52.0<1;1,0>:f {F@7} // ALU pipe: float; $451 mov (16|M16) r54.16<1>:hf r53.0<1;1,0>:f // ALU pipe: float; $451 mov (16|M0) r55.0<1>:hf r50.0<1;1,0>:f {F@7} // ALU pipe: float; $452 mov (16|M16) r55.16<1>:hf r51.0<1;1,0>:f // ALU pipe: float; $452 mov (16|M0) r56.0<1>:hf r48.0<1;1,0>:f {F@7} // ALU pipe: float; $453 mov (16|M16) r56.16<1>:hf r49.0<1;1,0>:f // ALU pipe: float; $453 mov (16|M0) r57.0<1>:hf r46.0<1;1,0>:f {F@7} // ALU pipe: float; $454 mov (16|M16) r57.16<1>:hf r47.0<1;1,0>:f // ALU pipe: float; $454 mov (16|M0) r58.0<1>:hf r44.0<1;1,0>:f // ALU pipe: float; $455 mov (16|M16) r58.16<1>:hf r45.0<1;1,0>:f // ALU pipe: float; $455 mov (16|M0) r59.0<1>:hf r42.0<1;1,0>:f // ALU pipe: float; $456 mov (16|M16) r59.16<1>:hf r43.0<1;1,0>:f // ALU pipe: float; $456 mov (16|M0) r60.0<1>:hf r40.0<1;1,0>:f // ALU pipe: float; $457 mov (16|M16) r60.16<1>:hf r41.0<1;1,0>:f // ALU pipe: float; $457 mov (16|M0) r61.0<1>:hf r38.0<1;1,0>:f // ALU pipe: float; $458 mov (16|M16) r61.16<1>:hf r39.0<1;1,0>:f // ALU pipe: float; $458 // B030: Preds:{B029, B019}, Succs:{B031, B032} _0_052: // Line 142: mask = (rm < M)[:, None] & (rn < N)[None, :] cmp (32|M0) (lt)f1.0 null<1>:d r116.0<1;1,0>:d r5.1<0;1,0>:d // ALU pipe: int; $463 (f1.0) cmp (32|M0) (lt)f1.0 null<1>:d r64.0<1;1,0>:d r5.0<0;1,0>:d // ALU pipe: int; $464 mov (16|M0) r2.0<1>:d 
r114.0<2;1,0>:d {Compacted} // ALU pipe: int; $461 mov (16|M16) r3.0<1>:d r112.0<2;1,0>:d {Compacted} // ALU pipe: int; $462 sync.nop null {Compacted,$5.src} // $466 (W) send.slm (1|M0) r1 r0 null:0 0x0 0x0210001F {$8} // wr:1+0, rd:1; fence.slm.none.group // $466 (W) mov (8|M0) null<1>:ud r1.0<1;1,0>:ud {Compacted,$8.dst} // memory fence commit; ALU pipe: int; $467 (W) mov (1|M0) r1.2<1>:f 0x0:f {Compacted,I@1} // signal barrier payload init; (0x00000000:f); ALU pipe: float; $467 (W) mov (2|M0) r1.10<1>:ub r0.11<0;1,0>:ub {F@1} // signal barrier payload (nprods, ncons); ALU pipe: int; $467 (W) send.gtwy (1|M0) null r1 null:0 0x0 0x02000004 {I@1,$9} // wr:1+0, rd:0; signal barrier // $467 (W) sync.bar 0x0 {Compacted} // $467 // Line 145: tl.store(C, acc, mask=mask) and (32|M0) (eq)f0.0 null<1>:d r2.0<1;1,0>:d 16:w // ALU pipe: int; $469 (W) mov (1|M0) r5.0<1>:hf 0x0:hf // ALU pipe: float; $471 mov (16|M0) r2.0<1>:d r114.0<2;1,0>:d {Compacted} // ALU pipe: int; $472 mov (16|M16) r3.0<1>:d r112.0<2;1,0>:d {Compacted} // ALU pipe: int; $472 sync.nop null {Compacted,F@1} // $471 (f0.0) sel (32|M0) r1.0<1>:w r5.0<0;1,0>:w 48:w {$9.src} // ALU pipe: int; $471 shl (32|M0) r8.0<1>:d r2.0<1;1,0>:d 1:w {Compacted,I@2} // ALU pipe: int; $473 mov (32|M0) r1.0<2>:b r1.0<1;1,0>:w {I@2} // ALU pipe: int; $475 and (32|M0) r6.0<1>:d r8.0<1;1,0>:d 30:w {Compacted,I@2} // ALU pipe: int; $474 (W) mov (1|M0) r5.0<1>:hf 0x60:hf // ALU pipe: float; $481 mov (32|M0) r8.0<1>:d r1.0<2;1,0>:ub {I@2} // ALU pipe: int; $476 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $477 (f0.0) sel (32|M0) r1.0<1>:w r5.0<0;1,0>:w -112:w {F@1} // ALU pipe: int; $481 mov (32|M0) r8.0<1>:ud r54.0<1;1,0>:uw // ALU pipe: int; $479 mov (32|M0) r1.0<2>:b r1.0<1;1,0>:w {I@2} // ALU pipe: int; $482 store.slm.d16u32.a32 (32|M0) [r10:2] r8:2 {I@2,$10} // ex_desc:0x0; desc:0x4000B04 // $480 (W) mov (1|M0) r5.0<1>:hf 0xFFC0:hf // ALU pipe: float; $488 sync.nop null {Compacted,I@1} 
// $483 mov (32|M0) r8.0<1>:d r1.0<2;1,0>:ub {$10.src} // ALU pipe: int; $483 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $484 (f0.0) sel (32|M0) r1.0<1>:w r5.0<0;1,0>:w -16:w {F@1} // ALU pipe: int; $488 mov (32|M0) r8.0<1>:ud r55.0<1;1,0>:uw // ALU pipe: int; $486 mov (32|M0) r1.0<2>:b r1.0<1;1,0>:w {I@2} // ALU pipe: int; $489 store.slm.d16u32.a32 (32|M0) [r10:2] r8:2 {I@2,$11} // ex_desc:0x0; desc:0x4000B04 // $487 sync.nop null {Compacted,I@1} // $490 mov (32|M0) r8.0<1>:d r1.0<2;1,0>:ub {$11.src} // ALU pipe: int; $490 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $491 mov (32|M0) r8.0<1>:ud r56.0<1;1,0>:uw // ALU pipe: int; $493 (W) mov (1|M0) r5.0<1>:hf 0x120:hf // ALU pipe: float; $495 store.slm.d16u32.a32 (32|M0) [r10:2] r8:2 {A@1,$12} // ex_desc:0x0; desc:0x4000B04 // $494 sync.nop null {Compacted,F@1} // $495 (f0.0) sel (32|M0) r8.0<1>:d r5.0<0;1,0>:w 336:w {$12.src} // ALU pipe: int; $495 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $496 mov (32|M0) r8.0<1>:ud r57.0<1;1,0>:uw // ALU pipe: int; $498 (W) mov (1|M0) r5.0<1>:hf 0x180:hf // ALU pipe: float; $500 store.slm.d16u32.a32 (32|M0) [r10:2] r8:2 {A@1,$13} // ex_desc:0x0; desc:0x4000B04 // $499 sync.nop null {Compacted,F@1} // $500 (f0.0) sel (32|M0) r8.0<1>:d r5.0<0;1,0>:w 432:w {$13.src} // ALU pipe: int; $500 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $501 mov (32|M0) r8.0<1>:ud r58.0<1;1,0>:uw // ALU pipe: int; $503 (W) mov (1|M0) r5.0<1>:hf 0x1E0:hf // ALU pipe: float; $505 store.slm.d16u32.a32 (32|M0) [r10:2] r8:2 {A@1,$14} // ex_desc:0x0; desc:0x4000B04 // $504 sync.nop null {Compacted,F@1} // $505 (f0.0) sel (32|M0) r8.0<1>:d r5.0<0;1,0>:w 528:w {$14.src} // ALU pipe: int; $505 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $506 mov (32|M0) r8.0<1>:ud r59.0<1;1,0>:uw // ALU pipe: int; $508 (W) 
mov (1|M0) r5.0<1>:hf 0x240:hf // ALU pipe: float; $510 store.slm.d16u32.a32 (32|M0) [r10:2] r8:2 {A@1,$15} // ex_desc:0x0; desc:0x4000B04 // $509 sync.nop null {Compacted,F@1} // $510 (f0.0) sel (32|M0) r8.0<1>:d r5.0<0;1,0>:w 624:w {$15.src} // ALU pipe: int; $510 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $511 mov (32|M0) r8.0<1>:ud r60.0<1;1,0>:uw // ALU pipe: int; $513 (W) mov (1|M0) r5.0<1>:hf 0x2A0:hf // ALU pipe: float; $515 store.slm.d16u32.a32 (32|M0) [r10:2] r8:2 {A@1,$0} // ex_desc:0x0; desc:0x4000B04 // $514 sync.nop null {Compacted,F@1} // $515 (f0.0) sel (32|M0) r8.0<1>:d r5.0<0;1,0>:w 720:w {$0.src} // ALU pipe: int; $515 add (32|M0) r10.0<1>:d r8.0<1;1,0>:d r6.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $516 mov (32|M0) r6.0<1>:ud r61.0<1;1,0>:uw // ALU pipe: int; $518 store.slm.d16u32.a32 (32|M0) [r10:2] r6:2 {A@1,$1} // ex_desc:0x0; desc:0x4000B04 // $519 (W) send.slm (1|M0) r1 r0 null:0 0x0 0x0210001F {$3} // wr:1+0, rd:1; fence.slm.none.group // $520 (W) mov (8|M0) null<1>:ud r1.0<1;1,0>:ud {Compacted,$3.dst} // memory fence commit; ALU pipe: int; $521 (W) mov (1|M0) r1.2<1>:f 0x0:f {Compacted,I@1} // signal barrier payload init; (0x00000000:f); ALU pipe: float; $521 (W) mov (2|M0) r1.10<1>:ub r0.11<0;1,0>:ub {F@1} // signal barrier payload (nprods, ncons); ALU pipe: int; $521 (W) send.gtwy (1|M0) null r1 null:0 0x0 0x02000004 {I@1,$4} // wr:1+0, rd:0; signal barrier // $521 (W) sync.bar 0x0 {Compacted} // $521 (~f1.0) goto (32|M0) _0_062 _0_062 // ALU pipe: int; $522 // B031: [inDivergent], Preds:{B030}, Succs:{B032} _0_063: shr (32|M0) r6.0<1>:d r2.0<1;1,0>:ud 1:w {$1.src} // ALU pipe: int; $524 shl (32|M0) r2.0<1>:d r62.0<1;1,0>:d 1:w {Compacted} // ALU pipe: int; $526 and (32|M0) r8.0<1>:d r6.0<1;1,0>:d 15:w {Compacted,I@2} // ALU pipe: int; $525 and (32|M0) r6.0<1>:d r2.0<1;1,0>:d 16:w {Compacted,I@2} // ALU pipe: int; $527 mad (32|M0) r2.0<1>:d r6.0<1;0>:d r8.0<1;0>:d 48:w {I@1} // ALU pipe: int; 
$528 load.slm.d32x4.a32 (32|M0) r9:8 [r2:2] {I@1,$5} // ex_desc:0x0; desc:0x4803500 // $529 // Line 141: C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) (W) mul (16|M0) acc0.0<1>:d r64.0<1;1,0>:d r5.10<0;1,0>:uw // ALU pipe: int; $531 macl (16|M0) r2.0<1>:d r64.0<1;1,0>:d r5.5<0;1,0>:d {$5.src} // ALU pipe: int; $531 (W) mul (16|M16) acc0.0<1>:d r65.0<1;1,0>:d r5.10<0;1,0>:uw // ALU pipe: int; $531 macl (16|M16) r3.0<1>:d r65.0<1;1,0>:d r5.5<0;1,0>:d // ALU pipe: int; $532 add (32|M0) r6.0<1>:d r2.0<1;1,0>:d r116.0<1;1,0>:d {Compacted,I@1} // ALU pipe: int; $532 sync.nop null {Compacted,I@1} // $534 mov (16|M0) r1.0<2>:ud r6.0<1;1,0>:ud {Compacted,$4.src} // ALU pipe: int; $534 shl (16|M0) r19.0<1>:q r1.0<2;1,0>:d 1:w {Compacted,I@1} // ALU pipe: int; $534 mov (16|M16) r1.0<2>:ud r7.0<1;1,0>:ud {Compacted} // ALU pipe: int; $534 shl (16|M16) r17.0<1>:q r1.0<2;1,0>:d 1:w {Compacted,I@1} // ALU pipe: int; $534 add (16|M0) r5.0<1>:q r19.0<1;1,0>:q r4.7<0;1,0>:q {Compacted,$2.dst} // ALU pipe: int; $535 add (16|M16) r7.0<1>:q r17.0<1;1,0>:q r4.7<0;1,0>:q {Compacted,I@2} // ALU pipe: int; $535 // Line 145: tl.store(C, acc, mask=mask) store.ugm.d32x4.a64 (32|M0) [r5:4] r9:8 {A@1,$5} // ex_desc:0x0; desc:0x8003584 // $537 // B032: Preds:{B031, B030}, Succs:{} _0_062: join (32|M0) L6824 // L6824: // Line 144: if SPLIT_K == 1: (W) mov (16|M0) r112.0<1>:f r0.0<1;1,0>:f {Compacted} // ALU pipe: float; $540 (W) send.gtwy (1|M0) null r112 null:0 0x0 0x02000010 {EOT,F@1,$0} // wr:1+0, rd:0; end of thread // $540 L6848: nop // $540 //.BankConflicts: 4 //.ByteRMWs: 3 // //.numALUInst: 451 //.accSubDef: 5 //.accSubUse: 9 //.accSubCandidateDef: 5 //.accSubCandidateUse: 9 // // //.singlePipeAtOneDistNum: 136 //.allAtOneDistNum: 54 //.syncInstCount: 15 //.tokenReuseCount: 0 //.AfterWriteTokenDepCount: 29 //.AfterReadTokenDepCount: 24 ```
whitneywhtsang commented 6 months ago

@chengjunlu Please confirm whether the provided OCL builtins are enough to have DPAS working with 32 threads per warp.

chengjunlu commented 5 months ago

> @chengjunlu Please confirm if the provided OCL builtins are enough to have DPAS working with 32 threads per warp.

I don't think there is such an API in OCL yet. I will create a new issue to track the OCL interface for DPAS with sub-group-size=32.

I think we can close this issue, as the GenISA DPAS intrinsic works as expected with sub-group-size=32.