nod-ai / iree-amd-aie

IREE plugin repository for the AMD AIE accelerator
Apache License 2.0

Understanding ESRGAN model through IREE #403

Closed newling closed 3 months ago

newling commented 3 months ago

I downloaded esrgan_fp32_linalg.mlir provided by @vivekkhandelwal1 in a teams channel, and ran

mkdir dispatches
/path/to/iree-compile --iree-hal-dump-executable-sources-to=./dispatches \
                      --iree-hal-target-backends=llvm-cpu \
                      esrgan_fp32_linalg.mlir -o cpu_exec.vmfb

This dumps each dispatch as an individual .mlir file in the directory "dispatches".

How many dispatches are there?

find . -name "*.mlir" | wc -l

There are 360 distinct dispatches. This is more than expected, since there are not that many distinct layers in the model. However, some of these dispatches are identical except for their weights. For example, when I compare 2 dispatches as follows:

vimdiff module_torch-jit-export_dispatch_165.mlir module_torch-jit-export_dispatch_195.mlir

I see that the only difference is in the weights that are embedded in the dispatch. Is there a way to make the weights function operands, so that the functions are unique?

$ diff module_torch-jit-export_dispatch_165.mlir module_torch-jit-export_dispatch_195.mlir
1c1
< hal.executable public @"torch-jit-export_dispatch_165" {
---
> hal.executable public @"torch-jit-export_dispatch_195" {
3c3
<     hal.executable.export public @"torch-jit-export_dispatch_165_conv_2d_nchw_fchw_1x32x250x250x96x3x3_f32" ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
---
>     hal.executable.export public @"torch-jit-export_dispatch_195_conv_2d_nchw_fchw_1x32x250x250x96x3x3_f32" ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
9c9
<       func.func @"torch-jit-export_dispatch_165_conv_2d_nchw_fchw_1x32x250x250x96x3x3_f32"() {
---
>       func.func @"torch-jit-export_dispatch_195_conv_2d_nchw_fchw_1x32x250x250x96x3x3_f32"() {
12c12
<         %cst_1 = arith.constant dense<[[-0.129037291, -1.422290e-01, -0.0807310119, -0.128244475, -0.0398165733, -0.119981401, -0.113069206, -0.126817971, -0.0444165319, -0.0474722497, -0.0828626826, -0.0859258696, -0.145084366, -0.15676716, -0.0708842278, -0.149930134, -0.0774273946, -0.0647515133, -0.0672570392, -0.165924489, -0.0497073941, -7.261850e-02, -0.0830464735, -0.0497843362, -0.0845366641, -0.106523126, -0.0766393915, -0.0538191646, -0.115074076, -0.0525894538, -0.340261906, -0.155870825]]> : tensor<1x32xf32>
---
>         %cst_1 = arith.constant dense<[[-0.0380056612, -0.0330086388, -0.0296086408, -0.0663379803, -0.121410556, -0.0711337253, -0.111735687, -0.0797394365, -0.0937284901, -0.094646424, -0.0800017416, -0.0950915217, -0.144881442, -0.0933115109, -0.0424141847, -0.0521149226, -0.0428891256, -0.126851648, -0.048052825, -0.0963224172, -0.0750880837, -0.027913332, -0.0797173604, -0.0918321759, -0.0825948789, -0.169388652, -0.17149733, -0.0586047694, -0.0869528278, -0.10293851, -5.990920e-02, -0.11020527]]> : tensor<1x32xf32>
14c14
<         %c51211008 = arith.constant 51211008 : index
---
>         %c48335616 = arith.constant 48335616 : index
17c17
<         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c51211008) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>
---
>         %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c48335616) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32x96x3x3xf32>>
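One way to confirm that many of the 360 dispatches differ only in their embedded weights is to normalize the parts the diff shows varying (the dispatch-number suffixes, the dense<...> constants, and the buffer-offset constants) and then hash the files. This is a sketch I'm adding for illustration, not from the issue; the two demo files below are tiny stand-ins for the real dumped dispatches:

```shell
# Sketch: count dispatch bodies that are structurally identical after
# normalizing names, dense<...> weight constants, and offset constants.
# The demo files are hypothetical stand-ins for the real dump.
mkdir -p /tmp/dispatches_demo && cd /tmp/dispatches_demo
printf 'func.func @dispatch_165() { %%cst = arith.constant dense<[1.0]> : tensor<1xf32> }\n' > a.mlir
printf 'func.func @dispatch_195() { %%cst = arith.constant dense<[2.0]> : tensor<1xf32> }\n' > b.mlir
for f in *.mlir; do
  sed -E -e 's/dispatch_[0-9]+/dispatch_N/g' \
         -e 's/dense<[^>]*>/dense<_>/g' \
         -e 's/%c[0-9]+/%cN/g' \
         "$f" | md5sum
done | sort -u | wc -l   # prints 1: both files hash identically after normalization
```

Running the same loop over the real `dispatches` directory would give the number of truly unique dispatch bodies rather than the raw file count from `find`.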
newling commented 3 months ago

Adding a flag suggested by @MaheshRavishankar to avoid inlining weights works excellently to reduce the number of dispatches. Running the following compilation:

iree-compile --iree-flow-inline-constants-max-byte-length=0 \
             --iree-hal-dump-executable-sources-to=./dispatches \
             --iree-hal-target-backends=llvm-cpu \
             esrgan_fp32_linalg.mlir \
             -o cpu_exec.vmfb

we have just 20 dispatches. 7 of these don't contain any linalg operations; they consist only of flow.dispatch.load and flow.dispatch.store ops. Of the remaining 13 dispatches, 12 contain linalg.conv_2d_nchw_fchw, and the last one is a linalg.generic whose iterator dimensions are all "parallel".

I've stitched the dispatches back together in the following gist: https://gist.github.com/newling/6828571a28ac2a2c33cd35c355ac903b
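The 12/1/7 breakdown above can be reproduced with a small grep sketch over the dumped sources. This is illustrative, not from the thread; the three demo files are hypothetical stand-ins for the real dump:

```shell
# Sketch: classify dumped dispatch sources by the linalg op they contain.
# The demo files are hypothetical stand-ins for the real dispatches directory.
mkdir -p /tmp/classify_demo && cd /tmp/classify_demo
echo 'linalg.conv_2d_nchw_fchw' > conv.mlir
echo 'linalg.generic' > generic.mlir
echo 'flow.dispatch.load' > copy.mlir
grep -l 'linalg\.conv_2d_nchw_fchw' *.mlir | wc -l   # convolution dispatches -> 1
grep -l 'linalg\.generic' *.mlir | wc -l             # elementwise dispatches -> 1
grep -L 'linalg\.' *.mlir | wc -l                    # dispatches with no linalg ops -> 1
```

On the real dump the three counts would be 12, 1, and 7 respectively, per the description above.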