intel / mlir-extensions

Intel® Extension for MLIR. A staging ground for MLIR dialects and tools for Intel devices using the MLIR toolchain.
Other
118 stars 45 forks source link

[xegpu-to-vc-func] Is `transpose_bit_width=16` supported? #895

Open dchigarev opened 1 week ago

dchigarev commented 1 week ago

I'm trying to transpose a 16x16xf16 matrix using `xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32}`, but the values are transposed in a "32-bit manner" even though `transpose_bit_width = 16`. Adjacent pairs of f16 values stay together, as if the matrix were treated as 16x8 32-bit elements (see the output below). Is this expected behavior or a bug?

Reproducer:

```mlir
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN:                                       --runner imex-cpu-runner -e main \
// RUN:                                       --entry-point-result=void \
// RUN:                                       --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime
module attributes {gpu.container_module} {
  gpu.module @transpose_16bit_loadnd attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @transpose_16bit_loadnd(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 1, 1>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      %1 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
      %2 = xegpu.load_nd %0 {transpose = array<i64: 1, 0>, transpose_bit_width = 16 : i32} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
      xegpu.store_nd %2, %1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
      gpu.return
    }
  }

  func.func @main() {
    %c1 = arith.constant 1 : index
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant -0.5 : f32
    %cf_upper = arith.constant 0.5 : f32
    %result = memref.alloc() : memref<16x16xf16>
    %resultc = memref.alloc() : memref<16x16xf16>
    %r_r = memref.cast %result : memref<16x16xf16> to memref<*xf16>
    call @fillResource1DRandomF16(%r_r, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
    %cast2 = memref.cast %result : memref<16x16xf16> to memref<*xf16>
    call @printMemrefF16(%cast2) : (memref<*xf16>) -> ()
    %gpu_result_index = gpu.alloc host_shared () : memref<16x16xf16>
    %gpu_result = gpu.alloc host_shared () : memref<16x16xf16>
    memref.copy %result, %gpu_result_index : memref<16x16xf16> to memref<16x16xf16>
    gpu.launch_func @transpose_16bit_loadnd::@transpose_16bit_loadnd blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%gpu_result_index : memref<16x16xf16>, %gpu_result : memref<16x16xf16>)
    memref.copy %gpu_result, %resultc : memref<16x16xf16> to memref<16x16xf16>
    %cast1 = memref.cast %resultc : memref<16x16xf16> to memref<*xf16>
    call @printMemrefF16(%cast1) : (memref<*xf16>) -> ()
    return
  }
  func.func private @printMemrefF16(memref<*xf16>)
  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
}
```
Output ``` Original matrix: [[-0.0335999, -0.108459, 0.454346, 0.173096, 0.291992, -0.437744, 0.150879, 0.243774, -0.118896, 0.390625, -0.337402, 0.184204, 0.148682, 0.109863, 0.131592, 0.167603], [-0.383789, 0.182007, 0.157837, 0.016922, 0.403564, -0.355957, -0.465576, -0.0371094, -0.167603, -0.0213776, -0.107849, -0.364502, -0.49292, -0.40625, -0.474121, 0.259277], [0.212158, -0.0585938, 0.307861, 0.357178, -0.0243073, -0.301514, 0.157715, 0.0397949, -0.115845, -0.0805054, 0.354248, 0.288818, -0.387695, 0.265137, -0.191528, 0.23584], [0.186035, 0.13623, 0.164795, 0.321777, -0.131348, 0.189575, 0.437744, -0.437256, -0.488281, 0.104675, 0.223145, 0.468994, 0.471436, 0.289551, -0.388184, 0.24231], [0.152832, -0.233521, 0.0818481, -0.445312, -0.0191803, 0.349854, 0.472168, -0.358398, -0.220459, 0.244751, -0.0543518, 0.000132799, 0.288086, 0.0359192, -0.0933838, 0.165527], [-0.0643311, -0.368896, 0.398438, 0.125854, 0.174438, 0.010788, 0.0161896, -0.0637817, -0.450928, -0.256104, -0.0791016, 0.197266, -0.274658, -0.172607, -0.0960693, 0.376221], [0.326416, 0.428223, 0.0844116, -0.111023, 0.288574, -0.287109, 0.147461, 0.489258, -0.109314, 0.0188751, -0.375732, 0.175903, -0.309082, -0.172852, -0.499756, -0.102051], [-0.395508, -0.160034, -0.210571, 0.429688, -0.302246, -0.0577393, -0.0242767, -0.174194, 0.21228, 0.110107, 0.34082, 0.348877, -0.255371, 0.156738, 0.143066, -0.0538025], [0.43042, -0.496338, 0.0446472, 0.376465, -0.153564, -0.231934, 0.322266, -0.2771, -0.272949, 0.0265045, 0.293457, -0.207886, -0.248657, -0.244141, 0.118164, -0.167969], [-0.318359, 0.33252, 0.192261, -0.403564, 0.23877, 0.078064, 0.400391, -0.290771, -0.12323, 0.0836182, 0.265381, -0.337891, -0.431396, 0.262207, 0.0490723, 0.0157623], [0.310059, -0.481201, -0.0360413, 0.371582, 0.39624, 0.413086, 0.307861, 0.499756, 0.0454102, -0.2771, 0.352783, -0.0714111, 0.184082, 0.4729, -0.0998535, -0.420166], [0.445312, -0.265381, -0.182983, -0.249146, -0.437256, -0.298828, -0.418701, -0.0429688, 
0.0679932, -0.256836, -0.38208, 0.378174, 0.0784302, -0.149658, 0.232544, -0.249634], [-0.318115, -0.179443, -0.33667, -0.43335, -0.129028, 0.0360718, -0.48999, 0.333984, 0.356934, 0.238159, -0.198608, 0.0809326, 0.0897827, 0.209839, -0.0469055, -0.409668], [0.345947, -0.0787354, 0.0486755, 0.098938, 0.0684204, 0.227295, 0.0414429, 0.465576, 0.204834, 0.419189, 0.297119, -0.347412, -0.0586853, 0.239746, 0.174805, 0.0572205], [0.261963, -0.0251923, 0.481201, -0.470703, -0.0614014, 0.305176, -0.439209, -0.0430603, 0.346924, -0.411377, 0.00965118, 0.00774765, -0.378906, 0.466309, -0.13623, -0.0748901], [-0.241821, 0.000869751, -0.336914, -0.0773926, 0.469238, -0.218994, 0.362305, 0.00957489, -0.297852, 0.0365906, -0.382568, 0.308594, 0.134277, -0.322998, -0.445557, 0.158325]] Transposed matrix (with transpose_bit_width=16, but seems like it's still 32): [[-0.0335999, -0.108459, -0.383789, 0.182007, 0.212158, -0.0585938, 0.186035, 0.13623, 0.152832, -0.233521, -0.0643311, -0.368896, 0.326416, 0.428223, -0.395508, -0.160034], [0.43042, -0.496338, -0.318359, 0.33252, 0.310059, -0.481201, 0.445312, -0.265381, -0.318115, -0.179443, 0.345947, -0.0787354, 0.261963, -0.0251923, -0.241821, 0.000869751], [0.454346, 0.173096, 0.157837, 0.016922, 0.307861, 0.357178, 0.164795, 0.321777, 0.0818481, -0.445312, 0.398438, 0.125854, 0.0844116, -0.111023, -0.210571, 0.429688], [0.0446472, 0.376465, 0.192261, -0.403564, -0.0360413, 0.371582, -0.182983, -0.249146, -0.33667, -0.43335, 0.0486755, 0.098938, 0.481201, -0.470703, -0.336914, -0.0773926], [0.291992, -0.437744, 0.403564, -0.355957, -0.0243073, -0.301514, -0.131348, 0.189575, -0.0191803, 0.349854, 0.174438, 0.010788, 0.288574, -0.287109, -0.302246, -0.0577393], [-0.153564, -0.231934, 0.23877, 0.078064, 0.39624, 0.413086, -0.437256, -0.298828, -0.129028, 0.0360718, 0.0684204, 0.227295, -0.0614014, 0.305176, 0.469238, -0.218994], [0.150879, 0.243774, -0.465576, -0.0371094, 0.157715, 0.0397949, 0.437744, -0.437256, 0.472168, 
-0.358398, 0.0161896, -0.0637817, 0.147461, 0.489258, -0.0242767, -0.174194], [0.322266, -0.2771, 0.400391, -0.290771, 0.307861, 0.499756, -0.418701, -0.0429688, -0.48999, 0.333984, 0.0414429, 0.465576, -0.439209, -0.0430603, 0.362305, 0.00957489], [-0.118896, 0.390625, -0.167603, -0.0213776, -0.115845, -0.0805054, -0.488281, 0.104675, -0.220459, 0.244751, -0.450928, -0.256104, -0.109314, 0.0188751, 0.21228, 0.110107], [-0.272949, 0.0265045, -0.12323, 0.0836182, 0.0454102, -0.2771, 0.0679932, -0.256836, 0.356934, 0.238159, 0.204834, 0.419189, 0.346924, -0.411377, -0.297852, 0.0365906], [-0.337402, 0.184204, -0.107849, -0.364502, 0.354248, 0.288818, 0.223145, 0.468994, -0.0543518, 0.000132799, -0.0791016, 0.197266, -0.375732, 0.175903, 0.34082, 0.348877], [0.293457, -0.207886, 0.265381, -0.337891, 0.352783, -0.0714111, -0.38208, 0.378174, -0.198608, 0.0809326, 0.297119, -0.347412, 0.00965118, 0.00774765, -0.382568, 0.308594], [0.148682, 0.109863, -0.49292, -0.40625, -0.387695, 0.265137, 0.471436, 0.289551, 0.288086, 0.0359192, -0.274658, -0.172607, -0.309082, -0.172852, -0.255371, 0.156738], [-0.248657, -0.244141, -0.431396, 0.262207, 0.184082, 0.4729, 0.0784302, -0.149658, 0.0897827, 0.209839, -0.0586853, 0.239746, -0.378906, 0.466309, 0.134277, -0.322998], [0.131592, 0.167603, -0.474121, 0.259277, -0.191528, 0.23584, -0.388184, 0.24231, -0.0933838, 0.165527, -0.0960693, 0.376221, -0.499756, -0.102051, 0.143066, -0.0538025], [0.118164, -0.167969, 0.0490723, 0.0157623, -0.0998535, -0.420166, 0.232544, -0.249634, -0.0469055, -0.409668, 0.174805, 0.0572205, -0.13623, -0.0748901, -0.445557, 0.158325]] ```
chencha3 commented 1 week ago

On the current public platform, only 32-bit transpose is supported.