nod-ai / iree-amd-aie

IREE plugin repository for the AMD AIE accelerator
Apache License 2.0

Deeplabv3 Conv2d Shapes #559

Open · yzhang93 opened 1 month ago

yzhang93 commented 1 month ago
  1. Stride 2 conv2d:

     %8 = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%3, %4, %c0_i32, %c0_i32 : tensor<1x515x515x3xi8>, tensor<3x3x3x32xi8>, i32, i32) outs(%7 : tensor<1x257x257x32xi32>) -> tensor<1x257x257x32xi32>

  2. Stride 1 conv2d with 1x1 filter:

     %8 = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4, %c0_i32, %c0_i32 : tensor<1x257x257x32xi8>, tensor<1x1x32x16xi8>, i32, i32) outs(%7 : tensor<1x257x257x16xi32>) -> tensor<1x257x257x16xi32>

     which can be converted to matmul_transpose_b (a sketch of this rewrite is shown after the tables below):

     %8 = linalg.matmul_transpose_b ins(%3, %4 : tensor<66049x32xi8>, tensor<16x32xi8>) outs(%7 : tensor<66049x16xi32>) -> tensor<66049x16xi32>

| Input | Weight | Output | Stride | Convert to matmul (MxNxK) |
| --- | --- | --- | --- | --- |
| 1x515x515x3xi8 | 3x3x3x32xi8 | 1x257x257x32xi32 | 2 | NA |
| 1x257x257x32xi8 | 1x1x32x16xi8 | 1x257x257x16xi32 | 1 | 66049x16x32 |
| 1x257x257x16xi8 | 1x1x16x96xi8 | 1x257x257x96xi32 | 1 | 66049x96x16 |
| 1x129x129x96xi8 | 1x1x96x24xi8 | 1x129x129x24xi32 | 1 | 16641x24x96 |
| 1x129x129x24xi8 | 1x1x24x144xi8 | 1x129x129x144xi32 | 1 | 16641x144x24 |
| 1x129x129x144xi8 | 1x1x144x24xi8 | 1x129x129x24xi32 | 1 | 16641x24x144 |
| 1x65x65x144xi8 | 1x1x144x32xi8 | 1x65x65x32xi32 | 1 | 4225x32x144 |
| 1x65x65x32xi8 | 1x1x32x192xi8 | 1x65x65x192xi32 | 1 | 4225x192x32 |
| 1x65x65x192xi8 | 1x1x192x32xi8 | 1x65x65x32xi32 | 1 | 4225x32x192 |
| 1x65x65x192xi8 | 1x1x192x64xi8 | 1x65x65x64xi32 | 1 | 4225x64x192 |
| 1x65x65x64xi8 | 1x1x64x384xi8 | 1x65x65x384xi32 | 1 | 4225x384x64 |
| 1x65x65x384xi8 | 1x1x384x64xi8 | 1x65x65x64xi32 | 1 | 4225x64x384 |
| 1x65x65x384xi8 | 1x1x384x96xi8 | 1x65x65x96xi32 | 1 | 4225x96x384 |
| 1x65x65x96xi8 | 1x1x96x576xi8 | 1x65x65x576xi32 | 1 | 4225x576x96 |
| 1x65x65x576xi8 | 1x1x576x96xi8 | 1x65x65x96xi32 | 1 | 4225x96x576 |
| 1x65x65x576xi8 | 1x1x576x160xi8 | 1x65x65x160xi32 | 1 | 4225x160x576 |
| 1x65x65x160xi8 | 1x1x160x960xi8 | 1x65x65x960xi32 | 1 | 4225x960x160 |
| 1x65x65x960xi8 | 1x1x960x160xi8 | 1x65x65x160xi32 | 1 | 4225x160x960 |
| 1x65x65x960xi8 | 1x1x960x320xi8 | 1x65x65x320xi32 | 1 | 4225x320x960 |
| 1x65x65x320xi8 | 1x1x320x256xi8 | 1x65x65x256xi32 | 1 | 4225x256x320 |
| 1x1x1x320xi8 | 1x1x320x256xi8 | 1x1x1x256xi32 | 1 | 1x256x320 |
| 1x65x65x512xi8 | 1x1x512x256xi8 | 1x65x65x256xi32 | 1 | 4225x256x512 |
| 1x65x65x256xi8 | 1x1x256x21xi8 | 1x65x65x21xi32 | 1 | 4225x21x256 |
  3. Depthwise conv2d:

%7 = linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4, %c0_i32, %c0_i32 : tensor<1x32x1x259x259xi8>, tensor<32x1x1x3x3xi8>, i32, i32) outs(%6 : tensor<1x32x1x257x257xi32>) -> tensor<1x32x1x257x257xi32>

| Input | Weight | Output | Stride |
| --- | --- | --- | --- |
| 1x32x1x259x259xi8 | 32x1x1x3x3xi8 | 1x32x1x257x257xi32 | 1 |
| 1x96x1x259x259xi8 | 96x1x1x3x3xi8 | 1x96x1x129x129xi32 | 2 |
| 1x144x1x131x131xi8 | 144x1x1x3x3xi8 | 1x144x1x129x129xi32 | 1 |
| 1x144x1x131x131xi8 | 144x1x1x3x3xi8 | 1x144x1x65x65xi32 | 2 |
| 1x192x1x67x67xi8 | 192x1x1x3x3xi8 | 1x192x1x65x65xi32 | 1 |
| 1x384x1x69x69xi8 | 384x1x1x3x3xi8 | 1x384x1x65x65xi32 | 1 |
| 1x576x1x69x69xi8 | 576x1x1x3x3xi8 | 1x576x1x65x65xi32 | 1 |
| 1x960x1x73x73xi8 | 960x1x1x3x3xi8 | 1x960x1x65x65xi32 | 1 |
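
As a reference for the "Convert to matmul" column above, here is a minimal sketch (not from the issue; the SSA names %in, %filter, %acc are hypothetical) of how a stride-1 1x1 conv such as the 66049x16x32 row collapses into a 2D matmul: the N/H/W dims fold into M = 1x257x257 = 66049, the output channels give N = 16, and the input channels give K = 32.

```mlir
// Illustrative sketch only; assumes zero points of 0 so the _q variant can be dropped.
%in2d  = tensor.collapse_shape %in [[0, 1, 2], [3]]
    : tensor<1x257x257x32xi8> into tensor<66049x32xi8>
%w2d   = tensor.collapse_shape %filter [[0, 1, 2], [3]]
    : tensor<1x1x32x16xi8> into tensor<32x16xi8>
%acc2d = tensor.collapse_shape %acc [[0, 1, 2], [3]]
    : tensor<1x257x257x16xi32> into tensor<66049x16xi32>
// With the collapsed filter kept as KxN this is a plain matmul; transposing
// %w2d to 16x32 instead gives the matmul_transpose_b form quoted in item 2.
%mm = linalg.matmul ins(%in2d, %w2d : tensor<66049x32xi8>, tensor<32x16xi8>)
                    outs(%acc2d : tensor<66049x16xi32>) -> tensor<66049x16xi32>
// The 1x257x257x16xi32 result is recovered with a tensor.expand_shape.
```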
yzhang93 commented 1 month ago

@newling @erwei-xilinx The above is a list of all the original conv shapes in the model without padding.

erwei-xilinx commented 1 month ago

Do they all have stride = 1?

yzhang93 commented 1 month ago

> Do they all have stride = 1?

Good point. I've updated the table to include stride.

yzhang93 commented 1 month ago

@newling The depthwise ops didn't get transposed to channels-last, because the pass only supports linalg::Conv2DNchwFchwOp conversion. https://github.com/iree-org/iree/blob/4de493af31e370ca2eb1bb590469ebbf76fc8d5b/compiler/src/iree/compiler/Preprocessing/Common/ConvertConvToChannelsLast.cpp#L452

We would have to extend the pass if we need the channels-last version; otherwise we can directly try lowering linalg.depthwise_conv_2d_nchw_chw or linalg.conv_2d_ngchw_gfchw_q.
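
For illustration only (not from the issue; %in, %filter, %acc are hypothetical), a hedged sketch of what the first depthwise shape could look like in channels-last form if the pass were extended, assuming the grouped NGCHW op maps to linalg.depthwise_conv_2d_nhwc_hwc_q with the group/channel dimension moved innermost:

```mlir
// Hypothetical channels-last counterpart of the 1x32x1x259x259 depthwise conv;
// the filter becomes HWC (3x3x32) and the output becomes NHWC (1x257x257x32).
%dw = linalg.depthwise_conv_2d_nhwc_hwc_q
        {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
        ins(%in, %filter, %c0_i32, %c0_i32
            : tensor<1x259x259x32xi8>, tensor<3x3x32xi8>, i32, i32)
        outs(%acc : tensor<1x257x257x32xi32>) -> tensor<1x257x257x32xi32>
```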