wyang20170113 opened 3 months ago
Looks like the insert_element in MMA16816SmemLoader::loadX4 tries to insert at index 32 into a vector of only 4 elements when lowering the following:
```mlir
%72 = triton_gpu.local_load %68 : !tt.memdesc<16x16x32xf32, #shared> -> tensor<16x16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 32}>> loc(#loc54)
%73 = triton_gpu.local_load %71 : !tt.memdesc<1x32x16xf32, #shared1> -> tensor<1x32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 32}>> loc(#loc56)
%74 = tt.dot %72, %73, %cst, inputPrecision = tf32 : tensor<16x16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 32}>> * tensor<1x32x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 32}>> -> tensor<16x16x16xf32, #mma> loc(#loc56)
```
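For reference, the shapes in the IR describe a batched matmul with a broadcast batch dimension. A minimal NumPy sketch (shapes taken from the IR above; the data is just zeros for illustration) shows what the tt.dot computes:

```python
import numpy as np

# Shapes from the IR: LHS tensor<16x16x32xf32>, RHS tensor<1x32x16xf32>.
a = np.zeros((16, 16, 32), dtype=np.float32)
b = np.zeros((1, 32, 16), dtype=np.float32)

# np.matmul broadcasts the leading batch dim (16 vs 1) and contracts K = 32,
# matching the tensor<16x16x16xf32> result type of the tt.dot.
c = np.matmul(a, b)
print(c.shape)  # (16, 16, 16)
```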
Crash is happening here:
It is crashing because canonWidth is 32, which goes out of the bounds of the retElems SmallVector that holds the 4 elements for loadX4. I think if the inputs were bf16 it would probably take the ldmatrix path instead.
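A toy Python model of that failure mode (the names mirror the report; the values are assumptions, not the actual C++ code):

```python
# loadX4 yields only 4 return elements, but the insert index is driven by
# canonWidth = 32 (derived from kWidth = 32 in the dot_op layout), so any
# insert at that index overruns the vector.
ret_elems = [0.0] * 4
canon_width = 32

out_of_bounds = canon_width >= len(ret_elems)
print(out_of_bounds)  # True: index 32 cannot address a 4-element vector
```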
I am not sure how to fix this one.
Running the above code led to the following error:

```
fp32[constexpr[8], constexpr[16], constexpr[64]]
fp32[constexpr[64], constexpr[64]]
fp32[constexpr[8], constexpr[16], constexpr[64]]
Segmentation fault (core dumped)
```

It seems that tl.dot is the problem, since replacing it with the following two lines makes the code work:

```python
o = d_mask[:, :, :, None] * weight[None, None, :, :]
o = tl.sum(o, axis=2)
```

I am curious what happened and why. Thanks very much; I appreciate your comments.
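As a sanity check, the broadcast-multiply-then-reduce workaround is mathematically the same contraction as the dot. A NumPy sketch with the shapes printed above (fp32[8,16,64] and fp32[64,64]; the random data is just for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
d_mask = rng.standard_normal((8, 16, 64)).astype(np.float32)
weight = rng.standard_normal((64, 64)).astype(np.float32)

# tl.dot-style batched matmul: contract over the shared 64-sized axis.
dot_result = np.matmul(d_mask, weight)                  # (8, 16, 64)

# The workaround: broadcast to an outer product, then reduce axis 2.
o = d_mask[:, :, :, None] * weight[None, None, :, :]    # (8, 16, 64, 64)
sum_result = o.sum(axis=2)                              # (8, 16, 64)

print(np.allclose(dot_result, sum_result, atol=1e-3))
```

So the workaround changes only how the contraction is lowered (elementwise multiply plus tl.sum rather than an MMA instruction), not what is computed, which is consistent with the crash being in the dot-operand shared-memory loading path.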