Closed ESI-SYD closed 6 months ago
https://github.com/intel/intel-xpu-backend-for-triton/pull/370 fixes a subset of the test_softmax tests.
@whitneywhtsang can we close this issue?
No, because only a subset of the test cases are passing.
I reduced the problem to a particular operation in the "_blocksparse_softmax_fwd" (in softmax.py) kernel. The kernel is launched using a grid of [3, 4, 2]:
print("grid: ", grid)
_blocksparse_softmax_fwd[grid](
    out, a, a.stride(0), lut,  #
    rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn#
    scale,  #
    is_causal,  #
    BLOCK_SIZE=block,  #
    ROW_SIZE=next_power_of_2(maxlut),  #
    IS_DENSE=is_dense,  #
    num_warps=num_warps(maxlut)  #
)
Here is the relevant trace:
___________________________ test_softmax[1-4-False] ____________________________
BLOCK = 1, WIDTH = 4, is_dense = False, device = 'xpu', Z = 2, H = 3
is_causal = True, scale = 0.4
...
grid: [3, 4, 2] <<<<<<<<
Inside the _blocksparse_softmax_fwd kernel we have this code section:
if IS_DENSE:
    ns = tl.arange(0, ROW_SIZE)
else:
    a = tl.num_programs(0)  # incorrect value is 96 and should be 3.
    b = tl.num_programs(1)  # this returns 4, which is correct.
    tl.device_print("a: ", a)
    tl.device_print("b: ", b)
    off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
    start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
    ns = start_n * BLOCK_SIZE + lane_n
Here the 2 print statements yield:
pid (1, 1, 0) idx () a: 96
pid (1, 1, 0) idx () a: 96
pid (1, 1, 0) idx () a: 96
pid (1, 1, 0) idx () a: 96
pid (1, 1, 0) idx () a: 96
....
pid (1, 1, 0) idx () b: 4
pid (1, 1, 0) idx () b: 4
pid (1, 1, 0) idx () b: 4
pid (1, 1, 0) idx () b: 4
pid (1, 1, 0) idx () b: 4
...
So tl.num_programs(0) prints 96 instead of 3. Note that the call tl.num_programs(1) actually yields the correct value (4).
Hard coding the correct value (3) for tl.num_programs(0) causes the kernel to yield the expected result.
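To make the effect on the LUT lookup concrete, here is a minimal sketch in plain Python (not Triton) of the off_lut arithmetic for this test case. The helper name lut_offset_delta is hypothetical, and offset is left out because only the term that depends on tl.num_programs changes.

```python
def lut_offset_delta(num_programs_0: int, num_programs_1: int, block_size: int) -> int:
    """Term added to `offset` when the kernel computes off_lut."""
    # Same expression and operator order as the kernel line.
    return 2 * num_programs_0 * num_programs_1 // block_size


BLOCK = 1  # BLOCK value from test_softmax[1-4-False]

expected = lut_offset_delta(3, 4, BLOCK)   # grid[0] = 3, grid[1] = 4
observed = lut_offset_delta(96, 4, BLOCK)  # value actually returned for axis 0

print(expected, observed)
```

With these inputs the term is 24 for the expected grid size and 768 for the observed one, so the subsequent tl.load reads start_n from the wrong position in LUT.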
I verified that the LLVM IR produced by Triton is correct. The tl.num_programs(0) call is correctly lowered to the OpenCL function size_t get_global_size(unsigned int):
%50 = call i64 @_Z15get_global_sizej(i32 0), !dbg !30
%51 = trunc i64 %50 to i32, !dbg !30
%52 = call i32 (ptr addrspace(2), ...) @_Z18__spirv_ocl_printf(ptr addrspace(2) @printfFormat_1, i32 %8, i32 %10, i32 %12, ptr addrspace(2) @printfPrefix_1, i32 %51)
Note:
% c++filt _Z15get_global_sizej
get_global_size(unsigned int)
The issue is a problem in the GEN dialect lowering of genx.grid.dim. This problem will be addressed by PR https://github.com/intel/llvm/pull/12709.
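For context on the two printed values: in OpenCL, get_global_size(d) counts work-items, and with uniform work-groups it equals get_num_groups(d) * get_local_size(d). The plain-Python sketch below shows that the observed 96 and 4 are consistent with that relation if the local size is 32 along dimension 0. That local size is an assumption, not something stated in this thread, and global_size is a hypothetical helper. It is an illustration only and says nothing about what the linked PR changes.

```python
def global_size(num_groups, local_size):
    """Per-dimension global size for uniform work-groups: groups * work-items per group."""
    return [g * l for g, l in zip(num_groups, local_size)]


grid = [3, 4, 2]            # launch grid from the trace
assumed_local = [32, 1, 1]  # ASSUMPTION: 32 work-items along dim 0; not stated in this thread

sizes = global_size(grid, assumed_local)
print(sizes)
```

Under that assumption the sizes along dimensions 0 and 1 are 96 and 4, matching the a and b values printed by the kernel.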
This case passes on the SPIRV path but fails on the LLVM path.