iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.56k stars 571 forks source link

`vector.transfer_read` legalization fails with PR [3540](https://github.com/google/iree/pull/3540) #3565

Closed: MaheshRavishankar closed this issue 3 years ago.

MaheshRavishankar commented 3 years ago

The `ConvertToSPIRV` pass fails to legalize `vector.transfer_read`. This is the input to the pass:

// Affine-map aliases used by the reproducer below.
// In @matmul_static_shape every vector.contract uses:
//   #map1 -> LHS         (d0, d2)
//   #map2 -> RHS         (d2, d1)
//   #map3 -> accumulator (d0, d1)
// with iterator_types ["parallel", "parallel", "reduction"].
// NOTE(review): #map0 is not referenced in the visible function body; it may be
// used elsewhere in the module — confirm before removing.
#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>

module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader, CooperativeMatrixNV], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU, {cooperative_matrix_properties_nv = [{a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32, m_size = 16 : i32, n_
size = 16 : i32, result_type = f16, scope = 3 : i32}], max_compute_shared_memory_size = 49152 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>, subgroup_size = 32 : i32}>} {                                                
  // Reproducer for the vector.transfer_read legalization failure in ConvertToSPIRV.
  //
  // Per (block_id.x, block_id.y) this kernel loops over %arg0 = 0..4096 step 32.
  // Each iteration reads 8 LHS tiles from %0, 8 RHS tiles from %1 and 16
  // accumulator tiles from %2 (all vector<16x16xf16>). It applies two chained
  // vector.contract ops per accumulator tile and writes the 16 results back to %2.
  //
  // NOTE(review): every transfer_read/transfer_write here moves
  // vector<16x16xf16>, but the memrefs have element type vector<4xf32> and the
  // padding value %cst is vector<4xf32>. This element-type mismatch looks like
  // the construct the SPIR-V conversion cannot legalize — confirm against the
  // pass's error output.
  func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
    %c4096 = constant 4096 : index
    %c0 = constant 0 : index
    %c16 = constant 16 : index
    %c32 = constant 32 : index
    %c48 = constant 48 : index
    %c64 = constant 64 : index
    %cst = constant dense<0.000000e+00> : vector<4xf32>
    %c8 = constant 8 : index
    // Interface buffers: %0 and %1 are the contraction operands, %2 is both
    // read (accumulators) and written (results) below.
    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x512xvector<4xf32>>
    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x512xvector<4xf32>>
    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x512xvector<4xf32>>
    %3 = "gpu.block_id"() {dimension = "x"} : () -> index
    %4 = "gpu.block_id"() {dimension = "y"} : () -> index
    // Each iteration covers two 16-wide reduction slices: %arg0 and %arg0 + 16 (%9).
    scf.for %arg0 = %c0 to %c4096 step %c32 {
      // Row base %5 = block_id.y * 64; column base %6 = block_id.x * 64.
      %5 = muli %4, %c64 : index
      %6 = muli %3, %c64 : index
      // Column indices are divided by 8 before indexing the memrefs.
      // Presumably this is because each vector<4xf32> element (16 bytes) packs
      // 8 f16 values — TODO confirm against the memref vectorization in PR 3540.
      %7 = divi_signed %arg0, %c8 : index
      // LHS tiles from %0: rows {%5, %5+16, %5+32, %5+48} x columns {%arg0/8, (%arg0+16)/8}.
      %8 = vector.transfer_read %0[%5, %7], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %9 = addi %arg0, %c16 : index
      %10 = divi_signed %9, %c8 : index
      %11 = vector.transfer_read %0[%5, %10], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %12 = addi %5, %c16 : index
      %13 = vector.transfer_read %0[%12, %7], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %14 = vector.transfer_read %0[%12, %10], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %15 = addi %5, %c32 : index
      %16 = vector.transfer_read %0[%15, %7], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %17 = vector.transfer_read %0[%15, %10], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %18 = addi %5, %c48 : index
      %19 = vector.transfer_read %0[%18, %7], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %20 = vector.transfer_read %0[%18, %10], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      // RHS tiles from %1: rows {%arg0, %arg0+16} x columns {%6/8, (%6+16)/8, (%6+32)/8, (%6+48)/8}.
      %21 = divi_signed %6, %c8 : index
      %22 = vector.transfer_read %1[%arg0, %21], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %23 = addi %6, %c16 : index
      %24 = divi_signed %23, %c8 : index
      %25 = vector.transfer_read %1[%arg0, %24], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %26 = addi %6, %c32 : index
      %27 = divi_signed %26, %c8 : index
      %28 = vector.transfer_read %1[%arg0, %27], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %29 = addi %6, %c48 : index
      %30 = divi_signed %29, %c8 : index
      %31 = vector.transfer_read %1[%arg0, %30], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %32 = vector.transfer_read %1[%9, %21], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %33 = vector.transfer_read %1[%9, %24], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %34 = vector.transfer_read %1[%9, %27], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %35 = vector.transfer_read %1[%9, %30], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      // Accumulator tiles from %2: rows {%5, %12, %15, %18} x columns {%21, %24, %27, %30}.
      %36 = vector.transfer_read %2[%5, %21], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %37 = vector.transfer_read %2[%5, %24], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %38 = vector.transfer_read %2[%5, %27], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %39 = vector.transfer_read %2[%5, %30], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %40 = vector.transfer_read %2[%12, %21], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %41 = vector.transfer_read %2[%12, %24], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %42 = vector.transfer_read %2[%12, %27], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %43 = vector.transfer_read %2[%12, %30], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %44 = vector.transfer_read %2[%15, %21], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %45 = vector.transfer_read %2[%15, %24], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %46 = vector.transfer_read %2[%15, %27], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %47 = vector.transfer_read %2[%15, %30], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %48 = vector.transfer_read %2[%18, %21], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %49 = vector.transfer_read %2[%18, %24], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %50 = vector.transfer_read %2[%18, %27], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %51 = vector.transfer_read %2[%18, %30], %cst : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      // Two chained contractions per accumulator tile, acc(d0,d1) += lhs(d0,d2) * rhs(d2,d1):
      // first with the %arg0 slice (LHS %8/%13/%16/%19, RHS %22/%25/%28/%31),
      // then with the %arg0+16 slice (LHS %11/%14/%17/%20, RHS %32/%33/%34/%35).
      %52 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %8, %22, %36 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %53 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %11, %32, %52 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %54 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %8, %25, %37 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %55 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %11, %33, %54 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %56 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %8, %28, %38 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %57 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %11, %34, %56 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %58 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %8, %31, %39 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %59 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %11, %35, %58 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %60 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %13, %22, %40 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %61 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %14, %32, %60 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %62 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %13, %25, %41 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %63 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %14, %33, %62 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %64 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %13, %28, %42 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %65 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %14, %34, %64 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %66 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %13, %31, %43 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %67 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %14, %35, %66 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %68 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %16, %22, %44 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %69 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %17, %32, %68 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %70 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %16, %25, %45 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %71 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %17, %33, %70 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %72 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %16, %28, %46 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %73 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %17, %34, %72 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %74 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %16, %31, %47 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %75 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %17, %35, %74 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %76 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %19, %22, %48 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %77 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %20, %32, %76 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %78 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %19, %25, %49 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %79 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %20, %33, %78 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %80 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %19, %28, %50 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %81 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %20, %34, %80 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %82 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %19, %31, %51 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %83 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %20, %35, %82 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      // Write the 16 updated tiles back to %2 at the same indices they were read from.
      vector.transfer_write %53, %2[%5, %21] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %55, %2[%5, %24] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %57, %2[%5, %27] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %59, %2[%5, %30] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %61, %2[%12, %21] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %63, %2[%12, %24] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %65, %2[%12, %27] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %67, %2[%12, %30] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %69, %2[%15, %21] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %71, %2[%15, %24] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %73, %2[%15, %27] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %75, %2[%15, %30] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %77, %2[%18, %21] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %79, %2[%18, %24] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %81, %2[%18, %27] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %83, %2[%18, %30] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
    }
    return
  }
  // Returns the constant triple (64, 64, 1). Judging by its name this is the
  // (x, y, z) workgroup count; the three shape arguments are never read.
  func @matmul_static_shape__num_workgroups__(%arg0: !shapex.ranked_shape<[4096,4096]>, %arg1: !shapex.ranked_shape<[4096,4096]>, %arg2: !shapex.ranked_shape<[4096,4096]>) -> (index, index, index) attributes {sym_visibility = "private"} {
    %count_xy = constant 64 : index
    %count_z = constant 1 : index
    return %count_xy, %count_xy, %count_z : index, index, index
  }
  // Interface declaring three StorageBuffer bindings in descriptor set 0:
  // two read-only operands (@arg0, @arg1) and one write-only result (@ret0).
  hal.interface @legacy_io attributes {sym_visibility = "private"} {
    // First operand, read-only.
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    // Second operand, read-only.
    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
    // Result, write-only.
    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
  }
}
MaheshRavishankar commented 3 years ago

Repro instructions (for now)

1) Use PR 3540.
2) Also apply patch D89744 on top of the third_party/llvm-project submodule.

iree-translate -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization}" iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir -mlir-disable-threading -print-ir-after-all
ThomasRaoux commented 3 years ago

This happens because the `Float16` capability is not set, which makes the cooperative matrix load illegal. I believe the capability is needed here.

MaheshRavishankar commented 3 years ago

I thought I had that. Thanks for the quick update here.

MaheshRavishankar commented 3 years ago

Cool. It works now.

MaheshRavishankar commented 3 years ago
// Affine map aliases for the module below. #map0 and #map1 are not
// referenced anywhere in that module as shown.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>
// Indexing maps of every vector.contract in the module, whose iterators are
// ["parallel", "parallel", "reduction"] over (d0, d1, d2):
//   #map2: LHS indexed by (d0, d2)
//   #map3: RHS indexed by (d2, d1)
//   #map4: accumulator/result indexed by (d0, d1)
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1)>

// Vectorized IR for a 4096x4096x4096 f16 matmul that fails to lower to SPIR-V
// (iree-opt -iree-codegen-convert-to-spirv). The target environment enables
// SPIR-V 1.5 with the Float16 and CooperativeMatrixNV capabilities.
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU, {cooperative_matrix_properties_nv = [{a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f16, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f32, scope = 3 : i32}], max_compute_shared_memory_size = 49152 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>, subgroup_size = 32 : i32}>} {
  // Entry point with 128x1x1 invocations per workgroup. Workgroup (x, y)
  // handles the 128x128 result tile at rows 128 * y and columns 128 * x,
  // walking the reduction dimension in chunks of 32.
  func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
    %c16 = constant 16 : index
    %c48 = constant 48 : index
    %c32 = constant 32 : index
    %c4096 = constant 4096 : index
    %c128 = constant 128 : index
    %c64 = constant 64 : index
    %c2 = constant 2 : index
    %c0 = constant 0 : index
    %c-1 = constant -1 : index
    %c-128 = constant -128 : index
    %cst = constant 0.000000e+00 : f16
    %cst_0 = constant dense<0.000000e+00> : vector<4xf32>
    %c8 = constant 8 : index
    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
    // The result buffer is typed with 512 vector<4xf32> elements per row.
    // 512 * 16 bytes is the size of 4096 f16 values, so each element covers
    // 8 f16 columns.
    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x512xvector<4xf32>>
    %3 = "gpu.block_id"() {dimension = "x"} : () -> index
    %4 = "gpu.block_id"() {dimension = "y"} : () -> index
    // Reduction loop over K = 4096 in chunks of 32.
    scf.for %arg0 = %c0 to %c4096 step %c32 {
      // Row origin (from block y) and column origin (from block x) of this
      // workgroup's 128x128 tile.
      %5 = muli %4, %c128 : index
      %6 = muli %3, %c128 : index
      // Staging buffers for the 128x32 LHS and 32x128 RHS tiles in memory
      // space 3. This is presumably workgroup memory, given the
      // use-workgroup-memory pipeline option.
      %7 = alloc() : memref<128x32xf16, 3>
      %8 = alloc() : memref<32x128xf16, 3>
      // %19 is the linearized invocation id (z, y, x) and %20 the total
      // invocation count. Together they stride the cooperative copies below.
      %9 = "gpu.thread_id"() {dimension = "x"} : () -> index
      %10 = "gpu.block_dim"() {dimension = "x"} : () -> index
      %11 = "gpu.thread_id"() {dimension = "y"} : () -> index
      %12 = "gpu.block_dim"() {dimension = "y"} : () -> index
      %13 = "gpu.thread_id"() {dimension = "z"} : () -> index
      %14 = "gpu.block_dim"() {dimension = "z"} : () -> index
      %15 = muli %13, %12 : index
      %16 = addi %15, %11 : index
      %17 = muli %14, %12 : index
      %18 = muli %16, %10 : index
      %19 = addi %18, %9 : index
      %20 = muli %17, %10 : index
      // Copy the 128x32 LHS tile (4096 elements) from %0 into %7.
      scf.for %arg1 = %19 to %c4096 step %20 {
        %123 = divi_signed %arg1, %c32 : index
        %124 = remi_signed %arg1, %c32 : index
        %125 = addi %5, %123 : index
        %126 = addi %arg0, %124 : index
        %127 = load %0[%125, %126] : memref<4096x4096xf16>
        store %127, %7[%123, %124] : memref<128x32xf16, 3>
      }
      spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
      // Copy the 32x128 RHS tile (4096 elements) from %1 into %8.
      scf.for %arg1 = %19 to %c4096 step %20 {
        %123 = divi_signed %arg1, %c128 : index
        %124 = remi_signed %arg1, %c128 : index
        %125 = addi %arg0, %123 : index
        %126 = addi %6, %124 : index
        %127 = load %1[%125, %126] : memref<4096x4096xf16>
        store %127, %8[%123, %124] : memref<32x128xf16, 3>
      }
      spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
      // Origin of this subgroup's 64x64 sub-tile within the 128x128 tile.
      // With sg = subgroup id, the results are:
      //   %31 = 64 * ((sg / 2) mod 2)   (row offset)
      //   %40 = 64 * (sg mod 2)         (column offset)
      // The cmpi/subi/select chains implement floor division by 2.
      %21 = gpu.subgroup_id : index
      %22 = divi_signed %21, %c2 : index
      %23 = muli %22, %c64 : index
      %24 = cmpi "slt", %22, %c0 : index
      %25 = subi %c-1, %22 : index
      %26 = select %24, %25, %22 : index
      %27 = divi_signed %26, %c2 : index
      %28 = subi %c-1, %27 : index
      %29 = select %24, %28, %27 : index
      %30 = muli %29, %c-128 : index
      %31 = addi %23, %30 : index
      %32 = muli %21, %c64 : index
      %33 = cmpi "slt", %21, %c0 : index
      %34 = subi %c-1, %21 : index
      %35 = select %33, %34, %21 : index
      %36 = divi_signed %35, %c2 : index
      %37 = subi %c-1, %36 : index
      %38 = select %33, %37, %36 : index
      %39 = muli %38, %c-128 : index
      %40 = addi %32, %39 : index
      // LHS 16x16 fragments at rows %31 + {0, 16, 32, 48} and K offsets {0, 16}.
      // NOTE(review): the discussion in this issue suspects these masked
      // transfer_reads are what fails to legalize.
      %41 = vector.transfer_read %7[%31, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %42 = vector.transfer_read %7[%31, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %43 = addi %31, %c16 : index
      %44 = vector.transfer_read %7[%43, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %45 = vector.transfer_read %7[%43, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %46 = addi %31, %c32 : index
      %47 = vector.transfer_read %7[%46, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %48 = vector.transfer_read %7[%46, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %49 = addi %31, %c48 : index
      %50 = vector.transfer_read %7[%49, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %51 = vector.transfer_read %7[%49, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      // RHS 16x16 fragments at K offsets {0, 16} and columns %40 + {0, 16, 32, 48}.
      %52 = vector.transfer_read %8[%c0, %40], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %53 = addi %40, %c16 : index
      %54 = vector.transfer_read %8[%c0, %53], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %55 = addi %40, %c32 : index
      %56 = vector.transfer_read %8[%c0, %55], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %57 = addi %40, %c48 : index
      %58 = vector.transfer_read %8[%c0, %57], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %59 = vector.transfer_read %8[%c16, %40], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %60 = vector.transfer_read %8[%c16, %53], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %61 = vector.transfer_read %8[%c16, %55], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %62 = vector.transfer_read %8[%c16, %57], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      // Read the 4x4 grid of 16x16 accumulator tiles from the result buffer.
      // Each f16 column index is divided by 8 to address the vector<4xf32>
      // elements of %2.
      // NOTE(review): the result type vector<16x16xf16> and the padding value
      // vector<4xf32> differ from each other; confirm this form is accepted.
      %63 = addi %5, %31 : index
      %64 = addi %6, %40 : index
      %65 = divi_signed %64, %c8 : index
      %66 = vector.transfer_read %2[%63, %65], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %67 = addi %6, %53 : index
      %68 = divi_signed %67, %c8 : index
      %69 = vector.transfer_read %2[%63, %68], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %70 = addi %6, %55 : index
      %71 = divi_signed %70, %c8 : index
      %72 = vector.transfer_read %2[%63, %71], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %73 = addi %6, %57 : index
      %74 = divi_signed %73, %c8 : index
      %75 = vector.transfer_read %2[%63, %74], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %76 = addi %5, %43 : index
      %77 = vector.transfer_read %2[%76, %65], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %78 = vector.transfer_read %2[%76, %68], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %79 = vector.transfer_read %2[%76, %71], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %80 = vector.transfer_read %2[%76, %74], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %81 = addi %5, %46 : index
      %82 = vector.transfer_read %2[%81, %65], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %83 = vector.transfer_read %2[%81, %68], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %84 = vector.transfer_read %2[%81, %71], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %85 = vector.transfer_read %2[%81, %74], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %86 = addi %5, %49 : index
      %87 = vector.transfer_read %2[%86, %65], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %88 = vector.transfer_read %2[%86, %68], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %89 = vector.transfer_read %2[%86, %71], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      %90 = vector.transfer_read %2[%86, %74], %cst_0 : memref<4096x512xvector<4xf32>>, vector<16x16xf16>
      // Each of the 16 output tiles is accumulated by two 16x16x16
      // contractions: K offset 0 first, then K offset 16 chained on its result.
      %91 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %41, %52, %66 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %92 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %42, %59, %91 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %93 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %41, %54, %69 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %94 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %42, %60, %93 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %95 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %41, %56, %72 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %96 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %42, %61, %95 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %97 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %41, %58, %75 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %98 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %42, %62, %97 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %99 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %44, %52, %77 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %100 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %45, %59, %99 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %101 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %44, %54, %78 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %102 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %45, %60, %101 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %103 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %44, %56, %79 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %104 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %45, %61, %103 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %105 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %44, %58, %80 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %106 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %45, %62, %105 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %107 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %47, %52, %82 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %108 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %48, %59, %107 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %109 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %47, %54, %83 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %110 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %48, %60, %109 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %111 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %47, %56, %84 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %112 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %48, %61, %111 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %113 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %47, %58, %85 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %114 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %48, %62, %113 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %115 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %50, %52, %87 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %116 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %51, %59, %115 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %117 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %50, %54, %88 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %118 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %51, %60, %117 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %119 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %50, %56, %89 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %120 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %51, %61, %119 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %121 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %50, %58, %90 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %122 = vector.contract {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction"]} %51, %62, %121 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      // Store the 16 updated accumulator tiles back at the positions they
      // were read from.
      vector.transfer_write %92, %2[%63, %65] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %94, %2[%63, %68] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %96, %2[%63, %71] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %98, %2[%63, %74] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %100, %2[%76, %65] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %102, %2[%76, %68] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %104, %2[%76, %71] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %106, %2[%76, %74] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %108, %2[%81, %65] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %110, %2[%81, %68] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %112, %2[%81, %71] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %114, %2[%81, %74] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %116, %2[%86, %65] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %118, %2[%86, %68] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %120, %2[%86, %71] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      vector.transfer_write %122, %2[%86, %74] : vector<16x16xf16>, memref<4096x512xvector<4xf32>>
      dealloc %7 : memref<128x32xf16, 3>
      dealloc %8 : memref<32x128xf16, 3>
    }
    return
  }
  // Referenced by vkspv.num_workgroups_fn above. Returns (32, 32, 1), which
  // matches 4096 / 128 tiles along each result dimension. The shape
  // arguments are not read.
  func @matmul_static_shape__num_workgroups__(%arg0: !shapex.ranked_shape<[4096,4096]>, %arg1: !shapex.ranked_shape<[4096,4096]>, %arg2: !shapex.ranked_shape<[4096,4096]>) -> (index, index, index) attributes {sym_visibility = "private"} {
    %c32 = constant 32 : index
    %c1 = constant 1 : index
    return %c32, %c32, %c1 : index, index, index
  }
  // Bindings consumed by the iree.placeholder ops in @matmul_static_shape:
  // two read-only operands and one write-only result, all in set 0.
  hal.interface @legacy_io attributes {sym_visibility = "private"} {
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
  }
}

The IR above fails to lower to SPIR-V. To reproduce the error from the snippet above, run:

iree-opt -iree-codegen-convert-to-spirv

To reproduce with the PR and patch above, use this command:

iree-opt  -pass-pipeline="iree-codegen-linalg-to-spirv-pipeline{use-vectorization use-workgroup-memory}" matmul_vectorization.mlir
MaheshRavishankar commented 3 years ago

Reopened for now.

MaheshRavishankar commented 3 years ago

This seems to be an issue with masked `vector.transfer_read` ops being generated during tile + promote + vectorize.

IR after each stage

--- After First level of tile+distribute ---                                                                                                             
// Stage 1 (first-level tile + distribute): each workgroup, identified by
// gpu.block_id x/y, owns one 128x128 tile of the 4096x4096 f16 result %2 and
// runs the matmul over the shared K dimension in chunks of 32.
func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
  %c0 = constant 0 : index
  %c32 = constant 32 : index
  %c4096 = constant 4096 : index
  // %0 and %1 are the matmul operands (bindings @arg0/@arg1); %2 is the result (@ret0).
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
  %3 = "gpu.block_id"() {dimension = "x"} : () -> index
  %4 = "gpu.block_id"() {dimension = "y"} : () -> index
  // %arg0 walks the reduction dimension: columns of %0 / rows of %1.
  scf.for %arg0 = %c0 to %c4096 step %c32 {
    // 128x32 slice of %0 at row block_id.y * 128, column %arg0.
    %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %6 = subview %0[%5, %arg0] [128, 32] [1, 1]  : memref<4096x4096xf16> to memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // 32x128 slice of %1 at row %arg0, column block_id.x * 128.
    %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %8 = subview %1[%arg0, %7] [32, 128] [1, 1]  : memref<4096x4096xf16> to memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // 128x128 output tile of %2 owned by this workgroup.
    %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %11 = subview %2[%9, %10] [128, 128] [1, 1]  : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // The "workgroup" marker records the stage this op has reached (presumably
    // it is matched by the next transformation pattern). The outs() type below
    // was wrapped onto two lines when the dump was captured; the break falls on
    // whitespace, so it still parses.
    linalg.matmul {__internal_linalg_transform__ = "workgroup", launch_info_key = "__op_num_0__"} ins(%6, %8 : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<128x128xf16, affine_map<(d0, d1)[s0
] -> (d0 * 4096 + s0 + d1)>>)
  }
  return
}

--- After Promotion  ---
// Stage 2 (promotion): same loop nest as stage 1. In each K iteration the
// 128x32 and 32x128 operand slices are first copied into buffers allocated in
// memory space 3 (the copies carry the "copy_to_workgroup_memory" marker). The
// matmul reads its inputs from those buffers and still writes the output tile
// %11 in place.
func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
  %c4096 = constant 4096 : index
  %c32 = constant 32 : index
  %c0 = constant 0 : index
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
  %3 = "gpu.block_id"() {dimension = "x"} : () -> index
  %4 = "gpu.block_id"() {dimension = "y"} : () -> index
  scf.for %arg0 = %c0 to %c4096 step %c32 {
    %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %6 = subview %0[%5, %arg0] [128, 32] [1, 1]  : memref<4096x4096xf16> to memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %8 = subview %1[%arg0, %7] [32, 128] [1, 1]  : memref<4096x4096xf16> to memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %11 = subview %2[%9, %10] [128, 128] [1, 1]  : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // Promoted buffers in memory space 3, viewed through full-size subviews.
    %12 = alloc() : memref<128x32xf16, 3>
    %13 = subview %12[0, 0] [128, 32] [1, 1]  : memref<128x32xf16, 3> to memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    %14 = alloc() : memref<32x128xf16, 3>
    %15 = subview %14[0, 0] [32, 128] [1, 1]  : memref<32x128xf16, 3> to memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    // Copy the operand slices into the promoted buffers.
    linalg.copy(%6, %13) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    linalg.copy(%8, %15) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    // The matmul now reads from the promoted buffers. The outs() type was
    // wrapped across two lines when captured; the break is on whitespace.
    linalg.matmul {__internal_linalg_transform__ = "workgroup_memory", launch_info_key = "__op_num_0__"} ins(%13, %15 : memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>, memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>) outs(%11 : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (
d0 * 4096 + s0 + d1)>>)
    // Both promoted buffers are freed at the end of every K iteration.
    dealloc %12 : memref<128x32xf16, 3>
    dealloc %14 : memref<32x128xf16, 3>
  }
  return
}

--- After Second level Tiling  ---
// Stage 3 (second-level tiling): the workgroup-level matmul from stage 2 is
// split across subgroups. The map d0 * 64 - (d0 floordiv 2) * 128 equals
// (d0 mod 2) * 64. With %16 = gpu.subgroup_id and %17 = %16 / 2, each subgroup
// handles a 64x64 output tile at row ((%16 / 2) mod 2) * 64 and column
// (%16 mod 2) * 64 inside the 128x128 workgroup tile.
func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
  %c4096 = constant 4096 : index
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c32 = constant 32 : index
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
  %3 = "gpu.block_id"() {dimension = "x"} : () -> index
  %4 = "gpu.block_id"() {dimension = "y"} : () -> index
  scf.for %arg0 = %c0 to %c4096 step %c32 {
    // Workgroup-level operand slices and output tile (unchanged from stage 2).
    %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %6 = subview %0[%5, %arg0] [128, 32] [1, 1]  : memref<4096x4096xf16> to memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %8 = subview %1[%arg0, %7] [32, 128] [1, 1]  : memref<4096x4096xf16> to memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %11 = subview %2[%9, %10] [128, 128] [1, 1]  : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // Promoted copies of the operand slices in memory space 3.
    %12 = alloc() : memref<128x32xf16, 3>
    %13 = subview %12[0, 0] [128, 32] [1, 1]  : memref<128x32xf16, 3> to memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    %14 = alloc() : memref<32x128xf16, 3>
    %15 = subview %14[0, 0] [32, 128] [1, 1]  : memref<32x128xf16, 3> to memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    linalg.copy(%6, %13) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    linalg.copy(%8, %15) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    // Subgroup distribution: %17 selects the row block, %16 the column block.
    %16 = gpu.subgroup_id : index
    %17 = divi_signed %16, %c2 : index
    // 64x32 rows of the promoted LHS for this subgroup.
    %18 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%17)
    %19 = subview %13[%18, 0] [64, 32] [1, 1]  : memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3> to memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>
    // 32x64 columns of the promoted RHS for this subgroup.
    %20 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%16)
    %21 = subview %15[0, %20] [32, 64] [1, 1]  : memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3> to memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>
    // 64x64 output tile for this subgroup.
    %22 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%17)
    %23 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%16)
    %24 = subview %11[%22, %23] [64, 64] [1, 1]  : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>> to memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // Marked "vectorize" for the next stage. The outs() type is kept on one
    // line: splitting it inside the symbol `s0` would not parse.
    linalg.matmul {__internal_linalg_transform__ = "vectorize", launch_info_key = "__op_num_0__"} ins(%19, %21 : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>) outs(%24 : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
    dealloc %12 : memref<128x32xf16, 3>
    dealloc %14 : memref<32x128xf16, 3>
  }
  return
}

--- After Vectorization ---                                                                                                                                                                                                                                                                                       
// Stage 4 (vectorization): the per-subgroup linalg.matmul from stage 3 is
// replaced by whole-tile vector ops. Three vector.transfer_read ops (all with
// masked = [false, false]) load the 64x32 LHS, the 32x64 RHS and the 64x64
// accumulator tiles. A vector.contract with matmul indexing maps combines them,
// and a vector.transfer_write stores the 64x64 result back to %24.
func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
  %c4096 = constant 4096 : index
  %c2 = constant 2 : index
  %c32 = constant 32 : index
  %c0 = constant 0 : index
  // Padding value operand for the vector.transfer_read ops below.
  %cst = constant 0.000000e+00 : f16
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
  %3 = "gpu.block_id"() {dimension = "x"} : () -> index
  %4 = "gpu.block_id"() {dimension = "y"} : () -> index
  scf.for %arg0 = %c0 to %c4096 step %c32 {
    %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %6 = subview %0[%5, %arg0] [128, 32] [1, 1]  : memref<4096x4096xf16> to memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %8 = subview %1[%arg0, %7] [32, 128] [1, 1]  : memref<4096x4096xf16> to memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %11 = subview %2[%9, %10] [128, 128] [1, 1]  : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %12 = alloc() : memref<128x32xf16, 3>
    %13 = subview %12[0, 0] [128, 32] [1, 1]  : memref<128x32xf16, 3> to memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    %14 = alloc() : memref<32x128xf16, 3>
    %15 = subview %14[0, 0] [32, 128] [1, 1]  : memref<32x128xf16, 3> to memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    // The workgroup-memory copies are not vectorized in this dump; they are still linalg.copy.
    linalg.copy(%6, %13) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    linalg.copy(%8, %15) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    %16 = gpu.subgroup_id : index
    %17 = divi_signed %16, %c2 : index
    %18 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%17)
    %19 = subview %13[%18, 0] [64, 32] [1, 1]  : memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3> to memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>
    %20 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%16)
    %21 = subview %15[0, %20] [32, 64] [1, 1]  : memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3> to memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>
    %22 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%17)
    %23 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%16)
    %24 = subview %11[%22, %23] [64, 64] [1, 1]  : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>> to memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // Whole-tile reads: LHS (64x32), RHS (32x64), accumulator (64x64).
    %25 = vector.transfer_read %19[%c0, %c0], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<64x32xf16>
    %26 = vector.transfer_read %21[%c0, %c0], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<32x64xf16>
    %27 = vector.transfer_read %24[%c0, %c0], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<64x64xf16>
    // Matmul as a contraction: (d0, d2) x (d2, d1) -> (d0, d1), reducing over d2.
    %28 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %25, %26, %27 : vector<64x32xf16>, vector<32x64xf16> into vector<64x64xf16>
    vector.transfer_write %28, %24[%c0, %c0] {masked = [false, false]} : vector<64x64xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    dealloc %12 : memref<128x32xf16, 3>
    dealloc %14 : memref<32x128xf16, 3>
  }
  return
}

--- After Vector Unroll ---
// IR of @matmul_static_shape after the vector unroll step. The single
// per-subgroup vector.contract (vector<64x32> x vector<32x64> into
// vector<64x64>) from the previous dump has been unrolled into 16x16x16
// contracts. Each vector.transfer_read/transfer_write moves one 16x16 tile.
// Every transfer op here carries `masked = [false, false]`.
func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
  %c4096 = constant 4096 : index
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f16
  // 16/32/48 are the tile offsets used by the unrolled 16x16 transfer ops.
  %c16 = constant 16 : index
  %c32 = constant 32 : index
  %c48 = constant 48 : index
  // Interface buffers: %0 and %1 are bound to the two arguments, %2 to the result.
  %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
  %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
  %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
  // Workgroup ids: y selects the 128-row band, x selects the 128-column band.
  %3 = "gpu.block_id"() {dimension = "x"} : () -> index
  %4 = "gpu.block_id"() {dimension = "y"} : () -> index
  // Reduction loop: %arg0 walks the shared dimension in steps of 32.
  // It is the column offset into %0 and the row offset into %1.
  scf.for %arg0 = %c0 to %c4096 step %c32 {
    %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %6 = subview %0[%5, %arg0] [128, 32] [1, 1]  : memref<4096x4096xf16> to memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    %8 = subview %1[%arg0, %7] [32, 128] [1, 1]  : memref<4096x4096xf16> to memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
    %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
    // 128x128 output tile owned by this workgroup.
    %11 = subview %2[%9, %10] [128, 128] [1, 1]  : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // Staging buffers in memory space 3. Both are filled by copies tagged
    // "copy_to_workgroup_memory" and deallocated at the end of each iteration.
    %12 = alloc() : memref<128x32xf16, 3>
    %13 = subview %12[0, 0] [128, 32] [1, 1]  : memref<128x32xf16, 3> to memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    %14 = alloc() : memref<32x128xf16, 3>
    %15 = subview %14[0, 0] [32, 128] [1, 1]  : memref<32x128xf16, 3> to memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    // NOTE(review): no barrier is visible between these copies and the reads
    // of %19/%21 below. Confirm that a later pass inserts one.
    linalg.copy(%6, %13) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3>
    linalg.copy(%8, %15) {__internal_linalg_transform__ = "copy_to_workgroup_memory"} : memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3>
    // Per-subgroup tile selection. The map d0 * 64 - (d0 floordiv 2) * 128
    // equals 64 * (d0 mod 2). The row offset therefore comes from
    // (subgroup_id / 2) mod 2 and the column offset from subgroup_id mod 2.
    // Presumably 4 subgroups (local_size 128, subgroup_size 32 in the target
    // env), each owning one 64x64 quadrant of %11. TODO confirm.
    %16 = gpu.subgroup_id : index
    %17 = divi_signed %16, %c2 : index
    %18 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%17)
    %19 = subview %13[%18, 0] [64, 32] [1, 1]  : memref<128x32xf16, affine_map<(d0, d1) -> (d0 * 32 + d1)>, 3> to memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>
    %20 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%16)
    %21 = subview %15[0, %20] [32, 64] [1, 1]  : memref<32x128xf16, affine_map<(d0, d1) -> (d0 * 128 + d1)>, 3> to memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>
    %22 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%17)
    %23 = affine.apply affine_map<(d0) -> (d0 * 64 - (d0 floordiv 2) * 128)>(%16)
    %24 = subview %11[%22, %23] [64, 64] [1, 1]  : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>> to memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // LHS tiles %25..%32: a 4x2 grid of 16x16 tiles read from %19 (64x32).
    %25 = vector.transfer_read %19[%c0, %c0], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    %26 = vector.transfer_read %19[%c0, %c16], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    %27 = vector.transfer_read %19[%c16, %c0], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    %28 = vector.transfer_read %19[%c16, %c16], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    %29 = vector.transfer_read %19[%c32, %c0], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    %30 = vector.transfer_read %19[%c32, %c16], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    %31 = vector.transfer_read %19[%c48, %c0], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    %32 = vector.transfer_read %19[%c48, %c16], %cst {masked = [false, false]} : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>, 3>, vector<16x16xf16>
    // RHS tiles %33..%40: a 2x4 grid of 16x16 tiles read from %21 (32x64).
    %33 = vector.transfer_read %21[%c0, %c0], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    %34 = vector.transfer_read %21[%c0, %c16], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    %35 = vector.transfer_read %21[%c0, %c32], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    %36 = vector.transfer_read %21[%c0, %c48], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    %37 = vector.transfer_read %21[%c16, %c0], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    %38 = vector.transfer_read %21[%c16, %c16], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    %39 = vector.transfer_read %21[%c16, %c32], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    %40 = vector.transfer_read %21[%c16, %c48], %cst {masked = [false, false]} : memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>, 3>, vector<16x16xf16>
    // Accumulator tiles %41..%56: a 4x4 grid of 16x16 tiles, row-major, read from %24 (64x64).
    %41 = vector.transfer_read %24[%c0, %c0], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %42 = vector.transfer_read %24[%c0, %c16], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %43 = vector.transfer_read %24[%c0, %c32], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %44 = vector.transfer_read %24[%c0, %c48], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %45 = vector.transfer_read %24[%c16, %c0], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %46 = vector.transfer_read %24[%c16, %c16], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %47 = vector.transfer_read %24[%c16, %c32], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %48 = vector.transfer_read %24[%c16, %c48], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %49 = vector.transfer_read %24[%c32, %c0], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %50 = vector.transfer_read %24[%c32, %c16], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %51 = vector.transfer_read %24[%c32, %c32], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %52 = vector.transfer_read %24[%c32, %c48], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %53 = vector.transfer_read %24[%c48, %c0], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %54 = vector.transfer_read %24[%c48, %c16], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %55 = vector.transfer_read %24[%c48, %c32], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    %56 = vector.transfer_read %24[%c48, %c48], %cst {masked = [false, false]} : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, vector<16x16xf16>
    // Unrolled contractions, two per output tile (i, j):
    //   - The first uses LHS column-tile 0, RHS row-tile 0 and the loaded
    //     accumulator tile.
    //   - The second uses LHS column-tile 1, RHS row-tile 1 and the result of
    //     the first.
    // Together they cover the 32-wide reduction in two 16-wide steps.
    %57 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %25, %33, %41 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %58 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %26, %37, %57 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %59 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %25, %34, %42 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %60 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %26, %38, %59 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %61 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %25, %35, %43 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %62 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %26, %39, %61 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %63 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %25, %36, %44 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %64 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %26, %40, %63 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %65 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %27, %33, %45 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %66 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %28, %37, %65 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %67 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %27, %34, %46 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %68 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %28, %38, %67 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %69 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %27, %35, %47 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %70 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %28, %39, %69 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %71 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %27, %36, %48 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %72 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %28, %40, %71 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %73 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %29, %33, %49 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %74 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %30, %37, %73 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %75 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %29, %34, %50 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %76 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %30, %38, %75 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %77 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %29, %35, %51 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %78 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %30, %39, %77 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %79 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %29, %36, %52 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %80 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %30, %40, %79 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %81 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %31, %33, %53 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %82 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %32, %37, %81 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %83 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %31, %34, %54 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %84 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %32, %38, %83 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %85 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %31, %35, %55 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %86 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %32, %39, %85 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %87 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %31, %36, %56 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    %88 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} %32, %40, %87 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
    // Store the final result of each tile (i, j) back to the matching 16x16 tile of %24.
    vector.transfer_write %58, %24[%c0, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %60, %24[%c0, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %62, %24[%c0, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %64, %24[%c0, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %66, %24[%c16, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %68, %24[%c16, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %70, %24[%c16, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %72, %24[%c16, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %74, %24[%c32, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %76, %24[%c32, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %78, %24[%c32, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %80, %24[%c32, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %82, %24[%c48, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %84, %24[%c48, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %86, %24[%c48, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    vector.transfer_write %88, %24[%c48, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
    // Release the per-iteration staging buffers.
    dealloc %12 : memref<128x32xf16, 3>
    dealloc %14 : memref<32x128xf16, 3>
  }
  return
}

So far, none of the `vector.transfer_read` ops are masked (all of them have `masked = [false, false]`). I will check the downstream passes next.

MaheshRavishankar commented 3 years ago

The issue seems to be in the `LegalizeStandardToSPIRV` pass.

IR before that pass:

// #map0-#map4 are strided 2-D memref layouts used by the subviews below:
// row strides 4096 (#map0), 32 (#map1, #map3) and 128 (#map2, #map4), with a
// symbolic offset s0 in #map0, #map3 and #map4. #map5 is the 2-D identity map.
// #map6-#map8 are the lhs (d0, d2), rhs (d2, d1) and acc (d0, d1) indexing
// maps of the vector.contract ops.
#map0 = affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>
#map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
#map2 = affine_map<(d0, d1) -> (d0 * 128 + d1)>
#map3 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
#map4 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
#map5 = affine_map<(d0, d1) -> (d0, d1)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map7 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map8 = affine_map<(d0, d1, d2) -> (d0, d1)>

// Input IR to LegalizeStandardToSPIRV. Every vector.transfer_read and
// vector.transfer_write in @matmul_static_shape below is marked
// masked = [false, false].
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, Grou\
pNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_ma\
trix]>, NVIDIA:DiscreteGPU, {cooperative_matrix_properties_nv = [{a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_\
type = f16, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f32, scope = 3 : i32}], max_compute_shared_memory_size = 49152 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[2147483\
647, 65535, 65535]> : vector<3xi32>, subgroup_size = 32 : i32}>} {
  // Each workgroup computes the 128x128 tile of %2 at rows %5 = block_id.y * 128
  // and columns %7 = block_id.x * 128. %arg0 iterates over 0..4096 in steps of 32
  // and is the column offset into %0 and the row offset into %1.
  func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
    %cst = constant 0.000000e+00 : f16
    %c16 = constant 16 : index
    %c48 = constant 48 : index
    %c32 = constant 32 : index
    %c4096 = constant 4096 : index
    %c128 = constant 128 : index
    %c64 = constant 64 : index
    %c2 = constant 2 : index
    %c0 = constant 0 : index
    %c-1 = constant -1 : index
    %c-128 = constant -128 : index
    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
    %3 = "gpu.block_id"() {dimension = "x"} : () -> index
    %4 = "gpu.block_id"() {dimension = "y"} : () -> index
    scf.for %arg0 = %c0 to %c4096 step %c32 {
      %5 = muli %4, %c128 : index
      %6 = subview %0[%5, %arg0] [128, 32] [1, 1]  : memref<4096x4096xf16> to memref<128x32xf16, #map0>
      %7 = muli %3, %c128 : index
      %8 = subview %1[%arg0, %7] [32, 128] [1, 1]  : memref<4096x4096xf16> to memref<32x128xf16, #map0>
      %9 = subview %2[%5, %7] [128, 128] [1, 1]  : memref<4096x4096xf16> to memref<128x128xf16, #map0>
      // Stage the current 128x32 tile of %0 and 32x128 tile of %1 in buffers
      // allocated in memory space 3 (presumably workgroup memory -- confirm).
      %10 = alloc() : memref<128x32xf16, 3>
      %11 = subview %10[0, 0] [128, 32] [1, 1]  : memref<128x32xf16, 3> to memref<128x32xf16, #map1, 3>
      %12 = alloc() : memref<32x128xf16, 3>
      %13 = subview %12[0, 0] [32, 128] [1, 1]  : memref<32x128xf16, 3> to memref<32x128xf16, #map2, 3>
      // %24 = linearized thread id ((tid.z * bdim.y + tid.y) * bdim.x + tid.x).
      // %25 = bdim.z * bdim.y * bdim.x, the total thread count.
      // Each copy loop below starts at %24 and steps by %25 over the
      // 4096 (= 128 * 32) elements of one staged tile.
      %14 = "gpu.thread_id"() {dimension = "x"} : () -> index
      %15 = "gpu.block_dim"() {dimension = "x"} : () -> index
      %16 = "gpu.thread_id"() {dimension = "y"} : () -> index
      %17 = "gpu.block_dim"() {dimension = "y"} : () -> index
      %18 = "gpu.thread_id"() {dimension = "z"} : () -> index
      %19 = "gpu.block_dim"() {dimension = "z"} : () -> index
      %20 = muli %18, %17 : index
      %21 = addi %20, %16 : index
      %22 = muli %19, %17 : index
      %23 = muli %21, %15 : index
      %24 = addi %23, %14 : index
      %25 = muli %22, %15 : index
      scf.for %arg1 = %24 to %c4096 step %25 {
        %113 = divi_signed %arg1, %c32 : index
        %114 = remi_signed %arg1, %c32 : index
        %115 = load %6[%113, %114] : memref<128x32xf16, #map0>
        store %115, %11[%113, %114] : memref<128x32xf16, #map1, 3>
      }
      // A workgroup-scope control barrier (AcquireRelease) follows each staging loop.
      spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
      scf.for %arg1 = %24 to %c4096 step %25 {
        %113 = divi_signed %arg1, %c128 : index
        %114 = remi_signed %arg1, %c128 : index
        %115 = load %8[%113, %114] : memref<32x128xf16, #map0>
        store %115, %13[%113, %114] : memref<32x128xf16, #map2, 3>
      }
      spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
      // Per-subgroup sub-tile offsets. Each cmpi/subi/select/divi_signed/select
      // sequence is a signed floor division by 2. With %27 = subgroup_id / 2:
      //   %36 = 64 * %27 - 128 * floordiv(%27, 2) = 64 * (%27 mod 2)
      //   %46 = 64 * (subgroup_id mod 2)
      // %37, %47 and %48 are the matching 64x32 slice of %11, 32x64 slice of
      // %13 and 64x64 slice of %9.
      %26 = gpu.subgroup_id : index
      %27 = divi_signed %26, %c2 : index
      %28 = muli %27, %c64 : index
      %29 = cmpi "slt", %27, %c0 : index
      %30 = subi %c-1, %27 : index
      %31 = select %29, %30, %27 : index
      %32 = divi_signed %31, %c2 : index
      %33 = subi %c-1, %32 : index
      %34 = select %29, %33, %32 : index
      %35 = muli %34, %c-128 : index
      %36 = addi %28, %35 : index
      %37 = subview %11[%36, 0] [64, 32] [1, 1]  : memref<128x32xf16, #map1, 3> to memref<64x32xf16, #map3, 3>
      %38 = muli %26, %c64 : index
      %39 = cmpi "slt", %26, %c0 : index
      %40 = subi %c-1, %26 : index
      %41 = select %39, %40, %26 : index
      %42 = divi_signed %41, %c2 : index
      %43 = subi %c-1, %42 : index
      %44 = select %39, %43, %42 : index
      %45 = muli %44, %c-128 : index
      %46 = addi %38, %45 : index
      %47 = subview %13[0, %46] [32, 64] [1, 1]  : memref<32x128xf16, #map2, 3> to memref<32x64xf16, #map4, 3>
      %48 = subview %9[%36, %46] [64, 64] [1, 1]  : memref<128x128xf16, #map0> to memref<64x64xf16, #map0>
      // Read 16x16 tiles at constant indices: 8 from %37, 8 from %47 and 16
      // accumulator tiles from %48. All reads are unmasked (masked = [false, false]).
      %49 = vector.transfer_read %37[%c0, %c0], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %50 = vector.transfer_read %37[%c0, %c16], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %51 = vector.transfer_read %37[%c16, %c0], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %52 = vector.transfer_read %37[%c16, %c16], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %53 = vector.transfer_read %37[%c32, %c0], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %54 = vector.transfer_read %37[%c32, %c16], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %55 = vector.transfer_read %37[%c48, %c0], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %56 = vector.transfer_read %37[%c48, %c16], %cst {masked = [false, false]} : memref<64x32xf16, #map3, 3>, vector<16x16xf16>
      %57 = vector.transfer_read %47[%c0, %c0], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %58 = vector.transfer_read %47[%c0, %c16], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %59 = vector.transfer_read %47[%c0, %c32], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %60 = vector.transfer_read %47[%c0, %c48], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %61 = vector.transfer_read %47[%c16, %c0], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %62 = vector.transfer_read %47[%c16, %c16], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %63 = vector.transfer_read %47[%c16, %c32], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %64 = vector.transfer_read %47[%c16, %c48], %cst {masked = [false, false]} : memref<32x64xf16, #map4, 3>, vector<16x16xf16>
      %65 = vector.transfer_read %48[%c0, %c0], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %66 = vector.transfer_read %48[%c0, %c16], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %67 = vector.transfer_read %48[%c0, %c32], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %68 = vector.transfer_read %48[%c0, %c48], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %69 = vector.transfer_read %48[%c16, %c0], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %70 = vector.transfer_read %48[%c16, %c16], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %71 = vector.transfer_read %48[%c16, %c32], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %72 = vector.transfer_read %48[%c16, %c48], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %73 = vector.transfer_read %48[%c32, %c0], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %74 = vector.transfer_read %48[%c32, %c16], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %75 = vector.transfer_read %48[%c32, %c32], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %76 = vector.transfer_read %48[%c32, %c48], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %77 = vector.transfer_read %48[%c48, %c0], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %78 = vector.transfer_read %48[%c48, %c16], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %79 = vector.transfer_read %48[%c48, %c32], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      %80 = vector.transfer_read %48[%c48, %c48], %cst {masked = [false, false]} : memref<64x64xf16, #map0>, vector<16x16xf16>
      // acc(d0, d1) += lhs(d0, d2) * rhs(d2, d1), per #map6/#map7/#map8 with d2
      // as the reduction. Each accumulator tile from %48 is updated by two
      // chained contractions, one per 16-wide k slice (e.g. %81 feeds %82).
      %81 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %49, %57, %65 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %82 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %50, %61, %81 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %83 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %49, %58, %66 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %84 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %50, %62, %83 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %85 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %49, %59, %67 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %86 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %50, %63, %85 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %87 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %49, %60, %68 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %88 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %50, %64, %87 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %89 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %51, %57, %69 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %90 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %52, %61, %89 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %91 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %51, %58, %70 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %92 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %52, %62, %91 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %93 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %51, %59, %71 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %94 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %52, %63, %93 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %95 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %51, %60, %72 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %96 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %52, %64, %95 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %97 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %53, %57, %73 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %98 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %54, %61, %97 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %99 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %53, %58, %74 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %100 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %54, %62, %99 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %101 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %53, %59, %75 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %102 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %54, %63, %101 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %103 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %53, %60, %76 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %104 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %54, %64, %103 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %105 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %55, %57, %77 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %106 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %56, %61, %105 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %107 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %55, %58, %78 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %108 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %56, %62, %107 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %109 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %55, %59, %79 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %110 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %56, %63, %109 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %111 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %55, %60, %80 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %112 = vector.contract {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} %56, %64, %111 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      // Store the 16 updated accumulator tiles back to %48 (unmasked).
      vector.transfer_write %82, %48[%c0, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %84, %48[%c0, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %86, %48[%c0, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %88, %48[%c0, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %90, %48[%c16, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %92, %48[%c16, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %94, %48[%c16, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %96, %48[%c16, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %98, %48[%c32, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %100, %48[%c32, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %102, %48[%c32, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %104, %48[%c32, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %106, %48[%c48, %c0] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %108, %48[%c48, %c16] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %110, %48[%c48, %c32] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      vector.transfer_write %112, %48[%c48, %c48] {masked = [false, false]} : vector<16x16xf16>, memref<64x64xf16, #map0>
      dealloc %10 : memref<128x32xf16, 3>
      dealloc %12 : memref<32x128xf16, 3>
    }
    return
  }
  // Workgroup-count function: returns the constant (32, 32, 1). This matches
  // 4096 / 128 = 32 output tiles along each of x and y.
  func @matmul_static_shape__num_workgroups__(%arg0: !shapex.ranked_shape<[4096,4096]>, %arg1: !shapex.ranked_shape<[4096,4096]>, %arg2: !shapex.ranked_shape<[4096,4096]>) -> (index, index, index) attributes {sym_visibility = "private"} {
    %c32 = constant 32 : index
    %c1 = constant 1 : index
    return %c32, %c32, %c1 : index, index, index
  }
  // Bindings for the placeholders above: @arg0 and @arg1 are read, @ret0 is written.
  hal.interface @legacy_io attributes {sym_visibility = "private"} {
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
  }
}

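IR after the pass. Before the pass, these `vector.transfer_read`s were `masked = [false, false]`. Afterwards, each one has a masked dimension:

- reads from `%7` are `masked = [true, false]`;
- reads from `%8` are `masked = [false, true]`;
- reads from `%2` have no `masked` attribute at all, which by the op's default means every dimension is masked.

This happens exactly where the `subview`s were folded into the transfers and the indices became dynamic (`%31`, `%40`). It looks like that folding does not preserve the original `masked` attribute.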

// #map0 is the 2-D identity map. #map1-#map3 are the lhs (d0, d2),
// rhs (d2, d1) and acc (d0, d1) indexing maps of the vector.contract ops.
// The strided subview layout maps from the input IR no longer appear: the
// allocs and loads below use plain memref types with the offsets folded
// into the indices.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>

module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.5, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, Grou\
pNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers, VariablePointersStorageBuffer, CooperativeMatrixNV], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers, SPV_NV_cooperative_ma\
trix]>, NVIDIA:DiscreteGPU, {cooperative_matrix_properties_nv = [{a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32, m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_\
type = f16, scope = 3 : i32}, {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32, m_size = 16 : i32, n_size = 16 : i32, result_type = f32, scope = 3 : i32}], max_compute_shared_memory_size = 49152 : i32, max_compute_workgroup_invocations = 1024 : i32, max_compute_workgroup_size = dense<[2147483\
647, 65535, 65535]> : vector<3xi32>, subgroup_size = 32 : i32}>} {
  func @matmul_static_shape() attributes {spv.entry_point_abi = {local_size = dense<[128, 1, 1]> : vector<3xi32>}, vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
    %c16 = constant 16 : index
    %c48 = constant 48 : index
    %c32 = constant 32 : index
    %c4096 = constant 4096 : index
    %c128 = constant 128 : index
    %c64 = constant 64 : index
    %c2 = constant 2 : index
    %c0 = constant 0 : index
    %c-1 = constant -1 : index
    %c-128 = constant -128 : index
    %cst = constant 0.000000e+00 : f16
    %0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
    %1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
    %2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
    %3 = "gpu.block_id"() {dimension = "x"} : () -> index
    %4 = "gpu.block_id"() {dimension = "y"} : () -> index
    scf.for %arg0 = %c0 to %c4096 step %c32 {
      %5 = muli %4, %c128 : index
      %6 = muli %3, %c128 : index
      %7 = alloc() : memref<128x32xf16, 3>
      %8 = alloc() : memref<32x128xf16, 3>
      %9 = "gpu.thread_id"() {dimension = "x"} : () -> index
      %10 = "gpu.block_dim"() {dimension = "x"} : () -> index
      %11 = "gpu.thread_id"() {dimension = "y"} : () -> index
      %12 = "gpu.block_dim"() {dimension = "y"} : () -> index
      %13 = "gpu.thread_id"() {dimension = "z"} : () -> index
      %14 = "gpu.block_dim"() {dimension = "z"} : () -> index
      %15 = muli %13, %12 : index
      %16 = addi %15, %11 : index
      %17 = muli %14, %12 : index
      %18 = muli %16, %10 : index
      %19 = addi %18, %9 : index
      %20 = muli %17, %10 : index
      scf.for %arg1 = %19 to %c4096 step %20 {
        %229 = divi_signed %arg1, %c32 : index
        %230 = remi_signed %arg1, %c32 : index
        %231 = addi %5, %229 : index
        %232 = addi %arg0, %230 : index
        %233 = load %0[%231, %232] : memref<4096x4096xf16>
        store %233, %7[%229, %230] : memref<128x32xf16, 3>
      }
      spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
      scf.for %arg1 = %19 to %c4096 step %20 {
        %229 = divi_signed %arg1, %c128 : index
        %230 = remi_signed %arg1, %c128 : index
        %231 = addi %arg0, %229 : index
        %232 = addi %6, %230 : index
        %233 = load %1[%231, %232] : memref<4096x4096xf16>
        store %233, %8[%229, %230] : memref<32x128xf16, 3>
      }
      spv.ControlBarrier "Workgroup", "Workgroup", "AcquireRelease"
      %21 = gpu.subgroup_id : index
      %22 = divi_signed %21, %c2 : index
      %23 = muli %22, %c64 : index
      %24 = cmpi "slt", %22, %c0 : index
      %25 = subi %c-1, %22 : index
      %26 = select %24, %25, %22 : index
      %27 = divi_signed %26, %c2 : index
      %28 = subi %c-1, %27 : index
      %29 = select %24, %28, %27 : index
      %30 = muli %29, %c-128 : index
      %31 = addi %23, %30 : index
      %32 = muli %21, %c64 : index
      %33 = cmpi "slt", %21, %c0 : index
      %34 = subi %c-1, %21 : index
      %35 = select %33, %34, %21 : index
      %36 = divi_signed %35, %c2 : index
      %37 = subi %c-1, %36 : index
      %38 = select %33, %37, %36 : index
      %39 = muli %38, %c-128 : index
      %40 = addi %32, %39 : index
      %41 = vector.transfer_read %7[%31, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %42 = vector.transfer_read %7[%31, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %43 = addi %31, %c16 : index
      %44 = vector.transfer_read %7[%43, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %45 = addi %31, %c16 : index
      %46 = vector.transfer_read %7[%45, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %47 = addi %31, %c32 : index
      %48 = vector.transfer_read %7[%47, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %49 = addi %31, %c32 : index
      %50 = vector.transfer_read %7[%49, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %51 = addi %31, %c48 : index
      %52 = vector.transfer_read %7[%51, %c0], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %53 = addi %31, %c48 : index
      %54 = vector.transfer_read %7[%53, %c16], %cst {masked = [true, false]} : memref<128x32xf16, 3>, vector<16x16xf16>
      %55 = vector.transfer_read %8[%c0, %40], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %56 = addi %40, %c16 : index
      %57 = vector.transfer_read %8[%c0, %56], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %58 = addi %40, %c32 : index
      %59 = vector.transfer_read %8[%c0, %58], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %60 = addi %40, %c48 : index
      %61 = vector.transfer_read %8[%c0, %60], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %62 = vector.transfer_read %8[%c16, %40], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %63 = addi %40, %c16 : index
      %64 = vector.transfer_read %8[%c16, %63], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %65 = addi %40, %c32 : index
      %66 = vector.transfer_read %8[%c16, %65], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %67 = addi %40, %c48 : index
      %68 = vector.transfer_read %8[%c16, %67], %cst {masked = [false, true]} : memref<32x128xf16, 3>, vector<16x16xf16>
      %69 = addi %5, %31 : index
      %70 = addi %6, %40 : index
      %71 = vector.transfer_read %2[%69, %70], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %72 = addi %40, %c16 : index
      %73 = addi %5, %31 : index
      %74 = addi %6, %72 : index
      %75 = vector.transfer_read %2[%73, %74], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %76 = addi %40, %c32 : index
      %77 = addi %5, %31 : index
      %78 = addi %6, %76 : index
      %79 = vector.transfer_read %2[%77, %78], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %80 = addi %40, %c48 : index
      %81 = addi %5, %31 : index
      %82 = addi %6, %80 : index
      %83 = vector.transfer_read %2[%81, %82], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %84 = addi %31, %c16 : index
      %85 = addi %5, %84 : index
      %86 = addi %6, %40 : index
      %87 = vector.transfer_read %2[%85, %86], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %88 = addi %31, %c16 : index
      %89 = addi %40, %c16 : index
      %90 = addi %5, %88 : index
      %91 = addi %6, %89 : index
      %92 = vector.transfer_read %2[%90, %91], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %93 = addi %31, %c16 : index
      %94 = addi %40, %c32 : index
      %95 = addi %5, %93 : index
      %96 = addi %6, %94 : index
      %97 = vector.transfer_read %2[%95, %96], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %98 = addi %31, %c16 : index
      %99 = addi %40, %c48 : index
      %100 = addi %5, %98 : index
      %101 = addi %6, %99 : index
      %102 = vector.transfer_read %2[%100, %101], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %103 = addi %31, %c32 : index
      %104 = addi %5, %103 : index
      %105 = addi %6, %40 : index
      %106 = vector.transfer_read %2[%104, %105], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %107 = addi %31, %c32 : index
      %108 = addi %40, %c16 : index
      %109 = addi %5, %107 : index
      %110 = addi %6, %108 : index
      %111 = vector.transfer_read %2[%109, %110], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %112 = addi %31, %c32 : index
      %113 = addi %40, %c32 : index
      %114 = addi %5, %112 : index
      %115 = addi %6, %113 : index
      %116 = vector.transfer_read %2[%114, %115], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %117 = addi %31, %c32 : index
      %118 = addi %40, %c48 : index
      %119 = addi %5, %117 : index
      %120 = addi %6, %118 : index
      %121 = vector.transfer_read %2[%119, %120], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %122 = addi %31, %c48 : index
      %123 = addi %5, %122 : index
      %124 = addi %6, %40 : index
      %125 = vector.transfer_read %2[%123, %124], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %126 = addi %31, %c48 : index
      %127 = addi %40, %c16 : index
      %128 = addi %5, %126 : index
      %129 = addi %6, %127 : index
      %130 = vector.transfer_read %2[%128, %129], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %131 = addi %31, %c48 : index
      %132 = addi %40, %c32 : index
      %133 = addi %5, %131 : index
      %134 = addi %6, %132 : index
      %135 = vector.transfer_read %2[%133, %134], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %136 = addi %31, %c48 : index
      %137 = addi %40, %c48 : index
      %138 = addi %5, %136 : index
      %139 = addi %6, %137 : index
      %140 = vector.transfer_read %2[%138, %139], %cst : memref<4096x4096xf16>, vector<16x16xf16>
      %141 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %41, %55, %71 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %142 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %42, %62, %141 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %143 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %41, %57, %75 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %144 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %42, %64, %143 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %145 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %41, %59, %79 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %146 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %42, %66, %145 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %147 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %41, %61, %83 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %148 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %42, %68, %147 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %149 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %44, %55, %87 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %150 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %46, %62, %149 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %151 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %44, %57, %92 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %152 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %46, %64, %151 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %153 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %44, %59, %97 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %154 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %46, %66, %153 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %155 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %44, %61, %102 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %156 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %46, %68, %155 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %157 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %48, %55, %106 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %158 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %50, %62, %157 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %159 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %48, %57, %111 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %160 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %50, %64, %159 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %161 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %48, %59, %116 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %162 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %50, %66, %161 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %163 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %48, %61, %121 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %164 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %50, %68, %163 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %165 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %52, %55, %125 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %166 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %54, %62, %165 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %167 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %52, %57, %130 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %168 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %54, %64, %167 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %169 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %52, %59, %135 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %170 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %54, %66, %169 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %171 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %52, %61, %140 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %172 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} %54, %68, %171 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
      %173 = addi %5, %31 : index
      %174 = addi %6, %40 : index
      vector.transfer_write %142, %2[%173, %174] : vector<16x16xf16>, memref<4096x4096xf16>
      %175 = addi %40, %c16 : index
      %176 = addi %5, %31 : index
      %177 = addi %6, %175 : index
      vector.transfer_write %144, %2[%176, %177] : vector<16x16xf16>, memref<4096x4096xf16>
      %178 = addi %40, %c32 : index
      %179 = addi %5, %31 : index
      %180 = addi %6, %178 : index
      vector.transfer_write %146, %2[%179, %180] : vector<16x16xf16>, memref<4096x4096xf16>
      %181 = addi %40, %c48 : index
      %182 = addi %5, %31 : index
      %183 = addi %6, %181 : index
      vector.transfer_write %148, %2[%182, %183] : vector<16x16xf16>, memref<4096x4096xf16>
      %184 = addi %31, %c16 : index
      %185 = addi %5, %184 : index
      %186 = addi %6, %40 : index
      vector.transfer_write %150, %2[%185, %186] : vector<16x16xf16>, memref<4096x4096xf16>
      %187 = addi %31, %c16 : index
      %188 = addi %40, %c16 : index
      %189 = addi %5, %187 : index
      %190 = addi %6, %188 : index
      vector.transfer_write %152, %2[%189, %190] : vector<16x16xf16>, memref<4096x4096xf16>
      %191 = addi %31, %c16 : index
      %192 = addi %40, %c32 : index
      %193 = addi %5, %191 : index
      %194 = addi %6, %192 : index
      vector.transfer_write %154, %2[%193, %194] : vector<16x16xf16>, memref<4096x4096xf16>
      %195 = addi %31, %c16 : index
      %196 = addi %40, %c48 : index
      %197 = addi %5, %195 : index
      %198 = addi %6, %196 : index
      vector.transfer_write %156, %2[%197, %198] : vector<16x16xf16>, memref<4096x4096xf16>
      %199 = addi %31, %c32 : index
      %200 = addi %5, %199 : index
      %201 = addi %6, %40 : index
      vector.transfer_write %158, %2[%200, %201] : vector<16x16xf16>, memref<4096x4096xf16>
      %202 = addi %31, %c32 : index
      %203 = addi %40, %c16 : index
      %204 = addi %5, %202 : index
      %205 = addi %6, %203 : index
      vector.transfer_write %160, %2[%204, %205] : vector<16x16xf16>, memref<4096x4096xf16>
      %206 = addi %31, %c32 : index
      %207 = addi %40, %c32 : index
      %208 = addi %5, %206 : index
      %209 = addi %6, %207 : index
      vector.transfer_write %162, %2[%208, %209] : vector<16x16xf16>, memref<4096x4096xf16>
      %210 = addi %31, %c32 : index
      %211 = addi %40, %c48 : index
      %212 = addi %5, %210 : index
      %213 = addi %6, %211 : index
      vector.transfer_write %164, %2[%212, %213] : vector<16x16xf16>, memref<4096x4096xf16>
      %214 = addi %31, %c48 : index
      %215 = addi %5, %214 : index
      %216 = addi %6, %40 : index
      vector.transfer_write %166, %2[%215, %216] : vector<16x16xf16>, memref<4096x4096xf16>
      %217 = addi %31, %c48 : index
      %218 = addi %40, %c16 : index
      %219 = addi %5, %217 : index
      %220 = addi %6, %218 : index
      vector.transfer_write %168, %2[%219, %220] : vector<16x16xf16>, memref<4096x4096xf16>
      %221 = addi %31, %c48 : index
      %222 = addi %40, %c32 : index
      %223 = addi %5, %221 : index
      %224 = addi %6, %222 : index
      vector.transfer_write %170, %2[%223, %224] : vector<16x16xf16>, memref<4096x4096xf16>
      %225 = addi %31, %c48 : index
      %226 = addi %40, %c48 : index
      %227 = addi %5, %225 : index
      %228 = addi %6, %226 : index
      vector.transfer_write %172, %2[%227, %228] : vector<16x16xf16>, memref<4096x4096xf16>
      dealloc %7 : memref<128x32xf16, 3>
      dealloc %8 : memref<32x128xf16, 3>
    }
    return
  }
  // Reports a constant triple of index values, (32, 32, 1). Judging by the
  // name, this is the workgroup count for the matmul dispatch; presumably the
  // results are the x, y and z extents (confirm against the HAL lowering).
  // The three ranked_shape arguments are all static 4096x4096 and are not
  // read by the body.
  func @matmul_static_shape__num_workgroups__(%arg0: !shapex.ranked_shape<[4096,4096]>, %arg1: !shapex.ranked_shape<[4096,4096]>, %arg2: !shapex.ranked_shape<[4096,4096]>) -> (index, index, index) attributes {sym_visibility = "private"} {
    // Extent used for the third result.
    %wg_z = constant 1 : index
    // Extent shared by the first and second results.
    %wg_xy = constant 32 : index
    return %wg_xy, %wg_xy, %wg_z : index, index, index
  }
  // Interface declaring three StorageBuffer bindings, all in set 0.
  // Bindings 0 and 1 (@arg0, @arg1) are declared with Read access;
  // binding 2 (@ret0) is declared with Write access.
  hal.interface @legacy_io attributes {sym_visibility = "private"} {
    hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
    hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
    hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
  }
}
ThomasRaoux commented 3 years ago

Fix under review: https://reviews.llvm.org/D89907

ThomasRaoux commented 3 years ago

The fix has landed in LLVM; I'll close this once it is integrated into IREE.