nod-ai / iree-amd-aie

IREE plugin repository for the AMD AIE accelerator
Apache License 2.0

[Conv2d] stride 2 test failed to legalize operation 'vector.extract_strided_slice' #581

Open · yzhang93 opened this issue 3 months ago

yzhang93 commented 3 months ago

Conv2d with stride = 2 example:

func.func @conv_2d_nhwc_hwcf(%arg0: tensor<1x129x129x16xi8>, %arg1: tensor<3x3x16x32xi8>) -> tensor<1x64x64x32xi32> {
  %cst = arith.constant 0 : i32
  %0 = tensor.empty() : tensor<1x64x64x32xi32>
  %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<1x64x64x32xi32>) -> tensor<1x64x64x32xi32>
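  // Output spatial dims: (129 - 3) / 2 + 1 = 64.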
  %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%arg0, %arg1 : tensor<1x129x129x16xi8>, tensor<3x3x16x32xi8>) outs(%1 : tensor<1x64x64x32xi32>) -> tensor<1x64x64x32xi32>
  return %2 : tensor<1x64x64x32xi32>
}

Error:

error: failed to legalize operation 'vector.extract_strided_slice' that was explicitly marked illegal
 %13 = vector.extract_strided_slice %10 {offsets = [0, 0, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x7x8xi8> to vector<1x1x8xi8>

Generated IR snippet:

%10 = vector.transfer_read %buf0[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x1x7x8xi8>, vector<1x7x8xi8>
%11 = vector.transfer_read %buf1[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x1x8x8xi8>, vector<1x8x8xi8>
%12 = vector.transfer_read %buf4[%c0, %c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x1x4x8xi32>, vector<1x4x8xi32>
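// Stride-2 selection: the four extracts below pick slices 0, 2, 4, 6 of the size-7 dimension, one per output position.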
%13 = vector.extract_strided_slice %10 {offsets = [0, 0, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x7x8xi8> to vector<1x1x8xi8>
%14 = vector.extract_strided_slice %10 {offsets = [0, 2, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x7x8xi8> to vector<1x1x8xi8>
%15 = vector.extract_strided_slice %10 {offsets = [0, 4, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x7x8xi8> to vector<1x1x8xi8>
%16 = vector.extract_strided_slice %10 {offsets = [0, 6, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x7x8xi8> to vector<1x1x8xi8>
%17 = vector.extract %11[0] : vector<8x8xi8> from vector<1x8x8xi8>
%18 = vector.extract_strided_slice %12 {offsets = [0, 0, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x4x8xi32> to vector<1x1x8xi32>
%19 = vector.extract_strided_slice %12 {offsets = [0, 1, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x4x8xi32> to vector<1x1x8xi32>
%20 = vector.extract_strided_slice %12 {offsets = [0, 2, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x4x8xi32> to vector<1x1x8xi32>
%21 = vector.extract_strided_slice %12 {offsets = [0, 3, 0], sizes = [1, 1, 8], strides = [1, 1, 1]} : vector<1x4x8xi32> to vector<1x1x8xi32>
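// Each selected input slice is sign-extended to i32 and contracted with the 8x8 kernel tap (%23), accumulating into one slice of the 1x4x8 output.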
%22 = arith.extsi %13 : vector<1x1x8xi8> to vector<1x1x8xi32>
%23 = arith.extsi %17 : vector<8x8xi8> to vector<8x8xi32>
%24 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %22, %23, %18 : vector<1x1x8xi32>, vector<8x8xi32> into vector<1x1x8xi32>
%25 = arith.extsi %14 : vector<1x1x8xi8> to vector<1x1x8xi32>
%26 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %25, %23, %19 : vector<1x1x8xi32>, vector<8x8xi32> into vector<1x1x8xi32>
%27 = arith.extsi %15 : vector<1x1x8xi8> to vector<1x1x8xi32>
%28 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %27, %23, %20 : vector<1x1x8xi32>, vector<8x8xi32> into vector<1x1x8xi32>
%29 = arith.extsi %16 : vector<1x1x8xi8> to vector<1x1x8xi32>
%30 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %29, %23, %21 : vector<1x1x8xi32>, vector<8x8xi32> into vector<1x1x8xi32>
%31 = vector.insert_strided_slice %24, %12 {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x1x8xi32> into vector<1x4x8xi32>
%32 = vector.insert_strided_slice %26, %31 {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<1x1x8xi32> into vector<1x4x8xi32>
%33 = vector.insert_strided_slice %28, %32 {offsets = [0, 2, 0], strides = [1, 1, 1]} : vector<1x1x8xi32> into vector<1x4x8xi32>
%34 = vector.insert_strided_slice %30, %33 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x8xi32> into vector<1x4x8xi32>
vector.transfer_write %34, %buf4[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x8xi32>, memref<1x1x4x8xi32>

@jsetoain It looks like vector.extract_strided_slice is not supported in aievec.

makslevental commented 3 months ago

Or possibly supported in mlir-aie's aievec but not ours... @yzhang93 in the future feel free to tag me as well on aievec-related issues since I "own" aievec in our/this plugin.

makslevental commented 3 months ago

grepping for ExtractStridedSliceOp in mlir-aie gets no hits, so I guess it is in fact not supported. We can discuss how to support it (now/here/tomorrow/whenever).
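One possible direction (just a sketch, not a confirmed aievec fix): each failing slice here takes a whole innermost row, so before aievec conversion it could be rewritten into a rank-reducing vector.extract plus a vector.shape_cast, assuming the rest of the pipeline accepts those two ops. Using %14 from the dump above:

// Equivalent to extract_strided_slice %10 {offsets = [0, 2, 0], sizes = [1, 1, 8], strides = [1, 1, 1]}.
%row = vector.extract %10[0, 2] : vector<8xi8> from vector<1x7x8xi8>
%14 = vector.shape_cast %row : vector<8xi8> to vector<1x1x8xi8>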

yzhang93 commented 3 months ago

> Or possibly supported in mlir-aie's aievec but not ours... @yzhang93 in the future feel free to tag me as well on aievec-related issues since I "own" aievec in our/this plugin.

Okay, sure! One related question: do we still have any dependency on MLIR-AIE in third_party?

makslevental commented 3 months ago

> Or possibly supported in mlir-aie's aievec but not ours... @yzhang93 in the future feel free to tag me as well on aievec-related issues since I "own" aievec in our/this plugin.

> Okay, sure! One related question: do we still have any dependency on MLIR-AIE in third_party?

Only op defs in tablegen https://github.com/nod-ai/iree-amd-aie/issues/430#issuecomment-2227174616

makslevental commented 3 months ago

@yzhang93 sorry for being a n00b - can you give me the iree-opt/iree-compile args/CLI to repro?

newling commented 3 months ago

I don't understand why there is a 7 here. Can't the tile size be chosen larger (or smaller) to make it the size we want (4 or 8)? @yzhang93 @erwei-xilinx

erwei-xilinx commented 3 months ago

My understanding is that the issue is the need for vector.extract_strided_slice. Currently the input feature map data feeding into the cores is laid out as <1xROWxInChan>, where the ROW elements are consecutive. So the core is trying to do the layout transform below in order to capture a 4x8 input for the vector intrinsic.

o o o o ... o                o o o o ... o
x x x x ... x                o o o o ... o
o o o o ... o                o o o o ... o
x x x x ... x       ->       o o o o ... o
o o o o ... o
x x x x ... x
o o o o ... o

(o = rows selected by the stride-2 access, x = rows skipped)

Maybe we could have the DMA do the above, so that the input data is already in our expected shape before feeding into the kernel?
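For illustration, a hypothetical core-side read if the DMA pre-gathers the four needed rows into a dense tile (buffer and constant names reused from the dump above; the layout itself is assumed):

// The 4x8 tile now holds only the stride-2 rows, laid out consecutively,
// so each output position needs a plain row extract instead of a strided slice.
%in = vector.transfer_read %buf0[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<1x1x4x8xi8>, vector<1x4x8xi8>
%row0 = vector.extract %in[0, 0] : vector<8xi8> from vector<1x4x8xi8>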

yzhang93 commented 3 months ago

> I don't understand why there is a 7 here. Can't the tile size be chosen larger (or smaller) to make it the size we want (4 or 8)? @yzhang93 @erwei-xilinx

This 7 comes from the input image width: the output tile width is a multiple of 4 and this is a stride-2 conv, so 4 output columns span input columns 0, 2, 4, 6, i.e. 2 * (4 - 1) + 1 = 7 columns, an odd number.

jsetoain commented 3 months ago

> I don't understand why there is a 7 here. Can't the tile size be chosen larger (or smaller) to make it the size we want (4 or 8)? @yzhang93 @erwei-xilinx

> This 7 comes from the input image width: the output tile width is a multiple of 4 and this is a stride-2 conv, so 4 output columns span input columns 0, 2, 4, 6, i.e. 2 * (4 - 1) + 1 = 7 columns, an odd number.

Maybe I'm misunderstanding, but wouldn't an 8x8 slice do exactly the same? I understand that you wouldn't need the 8th <8xi8> row because you only actually use the 1st, 3rd, 5th, and 7th 64-bit sub-slices, but the hardware can only load 256-bit and 512-bit chunks of memory (and scalars, of course).

newling commented 3 months ago

Ok, I think this can be solved by changing the padding at some level of the lowering.
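For reference, a sketch of what a padded read could look like at the vector level, assuming the lowering is allowed to read one element past the 7 valid ones (in_bounds = false on that dimension makes vector.transfer_read fill the out-of-bounds positions with the padding value), so the read of %10 above would become:

%10 = vector.transfer_read %buf0[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, false, true]} : memref<1x1x7x8xi8>, vector<1x8x8xi8>

The even-offset extracts at 0, 2, 4, 6 would stay as they are, but the source would now be a power-of-two 8x8 tile, matching the 256/512-bit loads mentioned above. Whether this alone satisfies the aievec legalization still depends on how extract_strided_slice gets supported.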