nod-ai / iree-amd-aie

IREE plugin repository for the AMD AIE accelerator
Apache License 2.0

Gemma's GEMM tracker #192

Status: Open · Abhishek-Varma opened this issue 8 months ago

Abhishek-Varma commented 8 months ago

I'll beautify this once I get hold of Azure storage.

I have attached gemma_7b.mlir along with the Gemma weights.

For now, I've uploaded all GEMM dispatches here.

GEMMs in the Gemma model appear in two forms:

  1. linalg.batch_matmul.
  2. linalg.matmul_transpose_b.
| Dispatch | Element type | Shape | Running on AIE |
|---|---|---|---|
| linalg.batch_matmul | f32 | 16x1x256xD | No |
| linalg.batch_matmul | f32 | 16x1xDx256 | No |
| linalg.batch_matmul | f32 | 16xDx256xD | No |
| linalg.batch_matmul | f32 | 16xDxDx256 | No |
| linalg.batch_matmul | f32 | 1x128x1x1 | No |
| linalg.batch_matmul | f32 | 1x128xDx1 | No |
| linalg.matmul_transpose_b | f32 | 1x256000x3072 | No |
| linalg.matmul_transpose_b | f32 | Dx256000x3072 | No |
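For reference, here is what the two op forms compute, as a plain-Python sketch (the function names mirror the linalg ops but are illustrative, not an IREE API). linalg.batch_matmul multiplies per batch, C[b][m][n] = sum over k of A[b][m][k] * B[b][k][n]; linalg.matmul_transpose_b takes B already transposed, C[m][n] = sum over k of A[m][k] * B[n][k]. Going by the types in the error log later in this issue, the 1x256000x3072 row corresponds to an A of shape 1x3072 (M x K) and a B of shape 256000x3072 (N x K), producing 1x256000.

```python
from typing import List

Matrix = List[List[int]]


def matmul_transpose_b(a: Matrix, b: Matrix) -> Matrix:
    """C[m][n] = sum_k A[m][k] * B[n][k]; A is M x K, B is N x K (stored transposed)."""
    k_dim = len(a[0])
    assert all(len(row) == k_dim for row in b), "A and B must share the K dimension"
    return [
        [sum(a_row[k] * b_row[k] for k in range(k_dim)) for b_row in b]
        for a_row in a
    ]


def batch_matmul(a: List[Matrix], b: List[Matrix]) -> List[Matrix]:
    """C[i][m][n] = sum_k A[i][m][k] * B[i][k][n]; A is Bt x M x K, B is Bt x K x N."""
    assert len(a) == len(b), "batch sizes must match"
    out = []
    for a_i, b_i in zip(a, b):
        k_dim, n_dim = len(b_i), len(b_i[0])
        out.append([
            [sum(row[k] * b_i[k][n] for k in range(k_dim)) for n in range(n_dim)]
            for row in a_i
        ])
    return out


# A 1 x K by N x K product, the same layout as the 1x256000x3072 dispatch (tiny sizes).
print(matmul_transpose_b([[1, 2, 3]], [[1, 0, 0], [0, 1, 1]]))  # [[1, 5]]
# One batch of a 1 x 2 by 2 x 1 product.
print(batch_matmul([[[1, 2]]], [[[3], [4]]]))  # [[[11]]]
```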

NOTE: I first tried compiling the Gemma model for llvm-cpu and found only batch_mmt4d dispatches; I've added those here.

```shell
iree-compile gemma_7b.mlir --iree-input-type=torch \
    --iree-hal-target-backends=llvm-cpu \
    --iree-hal-dump-executable-sources-to=GEMMA_DISPATCHES \
    -o test.vmfb
```

When I compiled the Gemma model for the amd-aie backend instead, I found the dispatches listed above.

```shell
iree-compile gemma_7b.mlir --iree-input-type=torch \
    --iree-hal-target-backends=amd-aie \
    --iree-hal-dump-executable-sources-to=GEMMA_DISPATCHES \
    -o test.vmfb
```

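To go from a dump directory like GEMMA_DISPATCHES to a list like the table above, a short script can search the dumped sources for the op names. This is a sketch under assumptions: it treats every file under the directory as text, it only matches the three op names mentioned in this issue, and the helper and script names (`classify_sources`, `find_gemm_dispatches`, `find_gemms.py`) are hypothetical. It reports which files contain each op, not the shapes, which still have to be read from the IR.

```python
import re
import sys
from pathlib import Path
from typing import Dict, List, Mapping, Set

# Matches the GEMM op names mentioned in this issue. The trailing word boundary
# keeps "linalg.batch_matmul" from also matching longer names such as
# "linalg.batch_matmul_transpose_b".
GEMM_OP_RE = re.compile(r"\blinalg\.(batch_matmul|matmul_transpose_b|batch_mmt4d)\b")


def classify_sources(sources: Mapping[str, str]) -> Dict[str, List[str]]:
    """Map each GEMM op name to the sorted names of the sources that contain it."""
    hits: Dict[str, Set[str]] = {}
    for name, text in sources.items():
        for match in GEMM_OP_RE.finditer(text):
            hits.setdefault("linalg." + match.group(1), set()).add(name)
    return {op: sorted(names) for op, names in hits.items()}


def find_gemm_dispatches(dump_dir: str) -> Dict[str, List[str]]:
    """Read every file under dump_dir (recursively) as text and classify it."""
    root = Path(dump_dir)
    sources = {
        str(path.relative_to(root)): path.read_text(errors="ignore")
        for path in sorted(root.rglob("*"))
        if path.is_file()
    }
    return classify_sources(sources)


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage (hypothetical script name): python find_gemms.py GEMMA_DISPATCHES
    for op, names in sorted(find_gemm_dispatches(sys.argv[1]).items()):
        print(f"{op}: {len(names)} file(s)")
        for name in names:
            print(f"  {name}")
```

Splitting the matching (`classify_sources`) from the file reading keeps the matching logic testable without a real dump directory.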
Abhishek-Varma commented 8 months ago

Since initial support for linalg.matmul_transpose_b has landed, I tried the 1x256000x3072 shape (with i32 rather than f32), and here is the end-to-end IR log.

The error occurs in the iree-amdaie-decompose-pack-unpack-to-air pass:

```
error: 'memref.expand_shape' op collapsed dim size (1) must equal reassociation group size (4)
        %7 = linalg.matmul_transpose_b ins(%3, %4 : tensor<1x3072xi32>, tensor<256000x3072xi32>) outs(%6 : tensor<1x256000xi32>) -> tensor<1x256000xi32>
       ^
note: see current operation: %25 = "memref.expand_shape"(%23) <{reassociation = [[0], [1], [2, 3], [4, 5]]}> : (memref<1x1x1x32xi32, strided<[3072, 3072, 3072, 1], offset: ?>, 1 : i32>) -> memref<1x1x1x4x4x8xi32, strided<[3072, 3072, 12288, 3072, 8, 1], offset: ?>, 1 : i32>
error: failed to run translation of source executable to target executable for backend #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd"}>
```
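The verifier message can be reproduced by hand. For a static memref.expand_shape, each source dimension must equal the product of the result dimensions in its reassociation group. The plain-Python sketch below (`check_expand_shape` is a hypothetical helper, not an MLIR API) applies that rule to the shapes in the failing op: source 1x1x1x32, result 1x1x1x4x4x8, reassociation [[0], [1], [2, 3], [4, 5]]. Group [2, 3] maps source dim 2 (size 1) to result dims of size 1 and 4, whose product is 4, which is exactly the "(1) must equal ... (4)" mismatch reported; group [4, 5] is consistent (4 x 8 = 32).

```python
from math import prod
from typing import List, Sequence, Tuple


def check_expand_shape(
    src_shape: Sequence[int],
    result_shape: Sequence[int],
    reassociation: Sequence[Sequence[int]],
) -> List[Tuple[int, int, int]]:
    """Return (source_dim, collapsed_size, group_size) for every group that breaks
    the static-shape rule: source size == product of result sizes in its group."""
    assert len(src_shape) == len(reassociation), "one group per source dim"
    errors = []
    for src_dim, group in enumerate(reassociation):
        group_size = prod(result_shape[d] for d in group)
        if src_shape[src_dim] != group_size:
            errors.append((src_dim, src_shape[src_dim], group_size))
    return errors


# Shapes copied from the failing memref.expand_shape in the log.
src = [1, 1, 1, 32]
res = [1, 1, 1, 4, 4, 8]
reassoc = [[0], [1], [2, 3], [4, 5]]
print(check_expand_shape(src, res, reassoc))  # [(2, 1, 4)]
```

Since this shape has M = 1, one hypothesis worth checking (not verified here) is whether the failing group corresponds to a tile of 4 being applied to that size-1 M dimension.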

CC: @MaheshRavishankar @nirvedhmeshram @yzhang93 @erwei-xilinx