iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Llama 3.1 8B f16 decomposed fails to compile after commit 2602a2a730b5c1439b79360644c9884377110080 #19290

Open aviator19941 opened 21 hours ago

aviator19941 commented 21 hours ago

What happened?

Llama 3.1 8B f16 decomposed fails to compile after commit 2602a2a730b5c1439b79360644c9884377110080 with this error: https://gist.github.com/aviator19941/8fd308c01a36967e3e4815575ad6ea4e

Steps to reproduce your issue

  1. Build IREE at commit 2602a2a730b5c1439b79360644c9884377110080 (a build sketch follows this list).
  2. Download IR: https://gist.github.com/aviator19941/9cf8da8bfc9da2ef5a47f87da6ce3045
  3. Compile:

     ../iree-build-no-trace/tools/iree-compile ../8b_f16_decomposed_11_22.mlir \
       --iree-hip-target=gfx942 \
       --iree-hal-target-backends=rocm \
       -o=8b_f16_test.vmfb \
       --iree-dispatch-creation-enable-aggressive-fusion=true \
       --iree-global-opt-propagate-transposes=true \
       --iree-opt-aggressively-propagate-transposes=true \
       --iree-opt-data-tiling=false \
       --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
       --iree-stream-resource-memory-model=discrete \
       --iree-hip-legacy-sync=false \
       --iree-hal-indirect-command-buffers=true \
       --iree-hal-memoization=true \
       --iree-opt-strip-assertions

  4. See error here: https://gist.github.com/aviator19941/8fd308c01a36967e3e4815575ad6ea4e
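
For step 1, a minimal build sketch, assuming a standard Ninja build of IREE at the culprit commit (the build directory name mirrors the compile command in step 3; CMake options on your machine may differ):

# Check out IREE at the culprit commit and build iree-compile.
git clone https://github.com/iree-org/iree.git
cd iree
git checkout 2602a2a730b5c1439b79360644c9884377110080
git submodule update --init

# Configure and build (RelWithDebInfo keeps the build debuggable;
# assertions are optional but help surface compiler bugs early).
cmake -G Ninja -B ../iree-build-no-trace -S . \
  -DCMAKE_BUILD_TYPE=RelWithDebInfo \
  -DIREE_ENABLE_ASSERTIONS=ON
cmake --build ../iree-build-no-trace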

What component(s) does this issue relate to?

Compiler

Version information

2602a2a730b5c1439b79360644c9884377110080

Additional context

No response

ScottTodd commented 21 hours ago

FYI, I can reproduce this error and a bisect (using the new scripting at https://github.com/iree-org/iree/pull/19289) pointed to the same culprit commit:

./bisect_releases.py \
  --good-ref=iree-3.0.0 \
  --bad-ref=upstream/main \
  --test-script=/home/nod/dev/data/iree-tmp/issue_19290.sh
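
The contents of issue_19290.sh aren't shown; a plausible sketch, assuming the bisect script follows the usual git-bisect-run convention (exit 0 for good, nonzero for bad) and reuses the repro command from the issue:

#!/bin/bash
# issue_19290.sh (hypothetical): compile the repro IR and report
# good/bad through the exit code, git-bisect-run style.
set -euo pipefail

iree-compile 8b_f16_decomposed_11_22.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  -o=/dev/null \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-stream-resource-memory-model=discrete \
  --iree-hip-legacy-sync=false \
  --iree-hal-indirect-command-buffers=true \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions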

2602a2a730b5c1439b79360644c9884377110080 is the first bad commit
commit 2602a2a730b5c1439b79360644c9884377110080
Author: Prashant Kumar <pk5561@gmail.com>
Date:   Fri Nov 22 18:08:43 2024 +0530

    [LLVMGPU] Use scf.forall for workgroup distribution (#18826)

    Enable scf.forall distribution for `tileAndBufferize`,
    `GPUWinogradVectorizePassPipeline`, `GPUMatmulSimtPassPipeline`,
    `GPUTransposePassPipeline` and `GPUPackUnPackPasses` pipelines.

git bisect log:
git bisect start '--no-checkout' '--first-parent'
# good: [29c451b00ecc9f9e5466e9d1079e0d69147da700] Yet more IREEGPUAttrs cleanup: drop `get{A,B,C}SingleSubgroupLayout` methods (#19169)
git bisect good 29c451b00ecc9f9e5466e9d1079e0d69147da700
# bad: [53e960146727759735815cac516683abb9bf5f86] Integrate llvm-project at fe3c23b439b9a2d00442d9bc6a4ca86f73066a3d (#19287)
git bisect bad 53e960146727759735815cac516683abb9bf5f86
# good: [41dcee93c7157955d94973addf6770cecf926849] Integrate LLVM at d7d6fb1804415b0f3e7f1cc9290bfb3d711cb707 (#19245)
git bisect good 41dcee93c7157955d94973addf6770cecf926849
# bad: [f55a5902db44bb6b728ff6f6b80d9fd10a1381d7] Update actions/cache version to latest release. (#19258)
git bisect bad f55a5902db44bb6b728ff6f6b80d9fd10a1381d7
# good: [4ee5d190cdfc977d7d5397db4fd14450e5effa9c] Add iree_codegen and iree_gpu dialects to Python readthedocs. (#19255)
git bisect good 4ee5d190cdfc977d7d5397db4fd14450e5effa9c
# good: [205af9200dc9c933fce06567ae141fba0424e537] [Global Opt] Turn on transpose propagation by default (#19253)
git bisect good 205af9200dc9c933fce06567ae141fba0424e537
# good: [e179a6e905a6682729aa816736f1004e84964ea4] [LLVMCPU] Migrate to TileRootAndFuseProducerConsumer pipeline (#19163)
git bisect good e179a6e905a6682729aa816736f1004e84964ea4
# bad: [2602a2a730b5c1439b79360644c9884377110080] [LLVMGPU] Use scf.forall for workgroup distribution (#18826)
git bisect bad 2602a2a730b5c1439b79360644c9884377110080
# first bad commit: [2602a2a730b5c1439b79360644c9884377110080] [LLVMGPU] Use scf.forall for workgroup distribution (#18826)

archana-ramalingam commented 17 hours ago

2602a2a throws a different error for batch sizes greater than 4 (bs > 4) on the same Llama model.

Command:

iree-compile llama3_decomposed.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  -o=8b_f16_decomposed.vmfb \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-stream-resource-memory-model=discrete \
  --iree-hip-legacy-sync=false \
  --iree-hal-indirect-command-buffers=true \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions

Error:

failed to translate executables
/home/aramalin/shark-ai/tmp_perplexity_ci_artifacts/llama8b_f16_decomposed.mlir:1476:12: error: 'func.func' op uses -268353536 bytes of shared memory; exceeded the limit of 65536 bytes
    %497 = torch.aten.bmm %491, %496 : !torch.vtensor<[160,?,?],f16>, !torch.vtensor<[160,?,128],f16> -> !torch.vtensor<[160,?,128],f16>
           ^
/home/aramalin/shark-ai/tmp_perplexity_ci_artifacts/llama8b_f16_decomposed.mlir:1476:12: note: see current operation: 
"func.func"() <{function_type = () -> (), sym_name = "prefill_bs5$async_dispatch_18_batch_matmul_160xDx128xD_f16"}> ({
  %0 = "arith.constant"() <{value = 8 : index}> : () -> index

Full trace: https://gist.github.com/archana-ramalingam/67bb380c03e58da8893ba57c09108b08
IR (bs=5): https://gist.github.com/archana-ramalingam/1b8d1c5e89d1012603254027f29203fe
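
For what it's worth, the negative shared-memory figure (-268353536 bytes) suggests the size computation overflows once the batch_matmul's dynamic dimensions come into play; that reading is an inference from the log, not confirmed. One way to pull the failing prefill_bs5$async_dispatch_18 dispatch out for a reduced repro is iree-compile's per-dispatch source dump; a sketch (the dump directory name here is arbitrary):

# Dump each dispatch's source module while compiling; executable sources
# are written before translation, so the failing batch_matmul dispatch
# can then be compiled and reduced in isolation.
iree-compile llama3_decomposed.mlir \
  --iree-hip-target=gfx942 \
  --iree-hal-target-backends=rocm \
  --iree-hal-dump-executable-sources-to=dispatch_sources \
  -o=/dev/null \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-stream-resource-memory-model=discrete \
  --iree-hip-legacy-sync=false \
  --iree-hal-indirect-command-buffers=true \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions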