iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

Llama-3-8B f16 fails to compile to vmfb #17226

Open · aviator19941 opened 4 months ago

aviator19941 commented 4 months ago

batch_llama_3_8B.zip

What happened?

When trying to compile this MLIR file, I get the shared memory error below:

failed to translate executables
failed to translate executables
failed to translate executables
result_llama_3_v4.mlir:352:7: error: 'func.func' op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes
      func.func @prefill_bs4$async_dispatch_1_generic_4xDx4096_i64xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x?xi64>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>) {
      ^
result_llama_3_v4.mlir:346:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_1 {
  ^
result_llama_3_v4.mlir:449:14: error: 'iree_linalg_ext.set_encoding' op unhandled tensor operation
        %7 = linalg.batch_matmul_transpose_b ins(%3, %4 : tensor<4x?x4096xf32>, tensor<4x4096x4096xf16>) outs(%6 : tensor<4x?x4096xf32>) -> tensor<4x?x4096xf32>
             ^
result_llama_3_v4.mlir:440:7: error: 'func.func' op failed to create tensor equivalance classes
      func.func @prefill_bs4$async_dispatch_4_batch_matmul_transpose_b_4xDx4096x4096_f32xf16xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<4x?x4096xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x4096x4096xf16>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>) {
      ^
result_llama_3_v4.mlir:434:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_4 {
  ^
result_llama_3_v4.mlir:504:14: error: 'iree_linalg_ext.set_encoding' op unhandled tensor operation
        %7 = linalg.batch_matmul_transpose_b ins(%3, %4 : tensor<4x?x4096xf32>, tensor<4x1024x4096xf16>) outs(%6 : tensor<4x?x1024xf32>) -> tensor<4x?x1024xf32>
             ^
result_llama_3_v4.mlir:495:7: error: 'func.func' op failed to create tensor equivalance classes
      func.func @prefill_bs4$async_dispatch_7_batch_matmul_transpose_b_4xDx1024x4096_f32xf16xf32(%arg0: !flow.dispatch.tensor<readonly:tensor<4x?x4096xf32>>, %arg1: !flow.dispatch.tensor<readonly:tensor<4x1024x4096xf16>>, %arg2: index, %arg3: !flow.dispatch.tensor<writeonly:tensor<4x?x1024xf32>>) {
      ^
result_llama_3_v4.mlir:489:3: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {mma_intrinsics = [#iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>], target_arch = "gfx940", ukernels = "none"}>
  flow.executable private @prefill_bs4$async_dispatch_7 {
  ^

Steps to reproduce your issue

  1. Cherry pick iree#17182
  2. Cherry pick llvm-project#90141
  3. ../iree-build/tools/iree-compile --mlir-disable-threading --iree-opt-const-eval=false --compile-to=flow ../batch_llama_3_8B.mlir -o result_llama_3.mlir
  4. ../iree-build/tools/iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-rocm-target-chip=gfx940 --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode result_llama_3.mlir -o llama_3.vmfb
  5. Error: the same failures shown above (negative shared memory usage exceeding the 65536-byte limit, followed by the 'iree_linalg_ext.set_encoding' and tensor equivalence class failures).

What component(s) does this issue relate to?

No response

Version information

f2746b464fb056ddadef4315654d59f727e4c9b0

Additional context

No response

benvanik commented 4 months ago

lol I'm guessing something is multiplying by a dynamic dimension (sentinel -1) without checking :P
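
To make that failure mode concrete, here is a tiny standalone C++ sketch (not IREE's code; the sentinel value and the shape are made up for illustration) of how a byte-size computation that multiplies an unchecked dynamic-dimension sentinel straight through ends up with a negative count, which is the flavor of the diagnostic above:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sentinel for a dynamic ("?") dimension; MLIR/IREE use their
  // own encoding, so this only illustrates the failure mode.
  const std::int64_t kDynamicDim = -1;
  const std::int64_t dims[] = {4, kDynamicDim, 4096};  // e.g. tensor<4x?x4096xf32>
  std::int64_t bytes = 4;  // sizeof(f32)
  for (std::int64_t d : dims)
    bytes *= d;  // no guard against the sentinel leaking into the product
  // Prints a negative byte count -- the same flavor as the
  // "uses -46137344 bytes of shared memory" diagnostic above.
  std::printf("computed shared memory: %lld bytes\n",
              static_cast<long long>(bytes));
  return 0;
}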

benvanik commented 4 months ago

(to reproduce we'll need the batch_llama_3_8B.mlir file, or the entire contents of the @prefill_bs4$async_dispatch_1_generic_4xDx4096_i64xf32 flow.executable/hal.executable op prior to the error)

aviator19941 commented 4 months ago

Yeah, I'll upload it here; I accidentally submitted the issue before uploading it :)

aviator19941 commented 4 months ago

@benvanik I uploaded a zip that has the batch_llama_3_8B.mlir file

hanhanW commented 4 months ago

It looks like it failed in SetEncoding (or related passes). @pashu123 given that you want to get more involved in these tasks, would you like to triage the issue when you're available?

pashu123 commented 4 months ago

@aviator19941 Do we need to cherry-pick a commit or check out a branch? On the main branch I am noticing this:

batch_llama_3_8B.mlir:1003:12: error: 'flow.tensor.reshape' op operand #2 must be variadic of index, but got 'i64'
    %339 = torch.aten.view %333, %338 : !torch.vtensor<[4,?,32,128],f32>, !torch.list<int> -> !torch.vtensor<[4,?,32,64,2],f32>
           ^
batch_llama_3_8B.mlir:1003:12: note: see current operation: %352 = "flow.tensor.reshape"(%331, %305, %351) <{operandSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<4x?x4096xf32>, index, i64) -> tensor<4x?x32x64x2xf32>

pashu123 commented 4 months ago

We need to cherry-pick this https://github.com/iree-org/iree/pull/17182 for the 1st command to work.

pashu123 commented 4 months ago

Here's the minimal repro https://gist.github.com/pashu123/45fe64caa21cfdfa9890698660184a44

pashu123 commented 4 months ago

This is failing at the // -----// IR Dump After GPUCheckResourceUsage Failed (iree-codegen-gpu-check-resource-usage) //----- // stage.

pashu123 commented 4 months ago

@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs?

pashu123 commented 4 months ago

A possible optimization that can be thought of is

 %0 = torch.prims.convert_element_type %arg1, %int6 : !torch.vtensor<[128256,4096],f16>, !torch.int -> !torch.vtensor<[128256,4096],f32>
 %1 = torch.aten.embedding %0, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f32>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

can be replaced by

 %0 = torch.aten.embedding %arg1, %arg0, %int-1, %false_0, %false : !torch.vtensor<[128256,4096],f16>, !torch.vtensor<[4,?],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,?,4096],f16>
 %1 = torch.prims.convert_element_type %0, %int6 : !torch.vtensor<[4,?,4096],f16>, !torch.int -> !torch.vtensor<[4,?,4096],f32>
    return %1 : !torch.vtensor<[4,?,4096],f32>

i.e., we don't need to cast the entire embedding matrix; we just cast what we want out of the matrix. The above repro takes forever to compile on the CPU backend. However, when we apply the optimization, we don't get the error: func.func op uses -46137344 bytes of shared memory; exceeded the limit of 65536 bytes.

hanhanW commented 4 months ago

It does not compile because vector.gather is lowered to a lot of vector ops -- which should be fixed.

The other issue is that we have two generic ops that are not fused in TileAndFuse, because there is no operand dependency between the two generic ops. This should be fixed before sending it to codegen. I don't have a good solution so far; perhaps we should just disable the fusion for this kind of case. @MaheshRavishankar do you have any suggestions?

func.func @decode_bs4$async_dispatch_0_generic_4xDx4096_i64xf32() {
  %c0 = arith.constant 0 : index
  %c32_i64 = arith.constant 32 : i64
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = arith.extui %0 : i32 to i64
  %3 = arith.extui %1 : i32 to i64
  %4 = arith.shli %3, %c32_i64 : i64
  %5 = arith.ori %2, %4 : i64
  %6 = arith.index_castui %5 : i64 to index
  %7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>>
  %8 = flow.dispatch.workload.ordinal %6, 0 : index
  %9 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%8}
  %10 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>{%8}
  %11 = flow.dispatch.tensor.load %7, offsets = [0, 0], sizes = [128256, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<128256x4096xf16>> -> tensor<128256x4096xf16>
  %12 = flow.dispatch.tensor.load %9, offsets = [0, 0], sizes = [4, %8], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%8} -> tensor<4x?xi64>
  %13 = tensor.empty(%8) : tensor<4x?x4096xf32>
  %14 = tensor.empty() : tensor<128256x4096xf32>
  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : tensor<128256x4096xf16>) outs(%14 : tensor<128256x4096xf32>) {
  ^bb0(%in: f16, %out: f32):
    %17 = arith.extf %in : f16 to f32
    linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
  ^bb0(%in: i64, %out: f32):
    %17 = arith.index_cast %in : i64 to index
    %18 = linalg.index 2 : index
    %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
    linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>
  flow.dispatch.tensor.store %16, %10, offsets = [0, 0, 0], sizes = [4, %8, 4096], strides = [1, 1, 1] : tensor<4x?x4096xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x?x4096xf32>>{%8}
  return
}

aviator19941 commented 4 months ago

@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs?

When I try to set the activation and attention dtypes to fp16 here, I run into

convertScalarToDtype should handle all the types
UNREACHABLE executed at iree/third_party/torch-mlir/lib/Conversion/Utils/Utils.cpp:355!

because it is trying to multiply complex<f16> and complex<f32> (repro). So I think it has to do with some dtype in the model that should be fp16, but is not.

aviator19941 commented 4 months ago

@aviator19941 The failure is due to https://gist.github.com/pashu123/020217a35f1c643ed03b169ce41f68d9 (embedding kernel). It has a cast from fp16 -> fp32. Please double-check that it's a full fp16 model. Also, could you post how to obtain the IRs?

In order to obtain the IRs:

  1. set up sharktank - sharktank
  2. rebase/checkout enable_llama3 branch
  3. clone and build llama.cpp - llama.cpp
  4. run export_paged_llm_v1 example - llama3 IR

pashu123 commented 4 months ago

Re: the convertScalarToDtype crash (complex<f16> * complex<f32>) hit when setting the activation and attention dtypes to fp16:

I think I can add the fix for this. It is required to enable the full Fp16 precision model.

pashu123 commented 4 months ago

@aviator19941 You can get the latest fp16 IR with wget https://huggingface.co/prashantk/test_files/resolve/main/batch_llama_v1.mlir?download=true.

It's able to generate the .vmfb with the llvm-cpu backend using the command iree-compile -iree-input-type=torch --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host batch_llama_v1.mlir -iree-opt-demote-i64-to-i32 -o llama3.vmfb

pashu123 commented 4 months ago

You need to cherry-pick https://github.com/iree-org/iree/pull/17247

hanhanW commented 4 months ago

I think there are still action items in this issue: the look-up table fusion is scaring me, and we should fix that at least. The tile sizes for vector.gather are also problematic; they lead to full unrolling, which looks really bad.

pashu123 commented 4 months ago

I think there are still action items in this issue: the look-up table fusion is scaring me, and we should fix that at least. The tile sizes for vector.gather are also problematic; they lead to full unrolling, which looks really bad.

I never intended to close the issue; I don't know if it got closed automatically. Yes, for the mixed precision case in which we have activations represented as f32, we still have action items to do.

hanhanW commented 4 months ago

Confirmed that the fusion is not expected. @MaheshRavishankar will fix it.

For the gather codegen issue, @pashu123 could you create an input case for the generic op and see what's happening? I'm expecting that some dimensions would be collapsed, and the next issue could be tile size selection. https://github.com/iree-org/iree/pull/17227 could help, but there could be other issues remaining on the table.

  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
  ^bb0(%in: i64, %out: f32):
    %17 = arith.index_cast %in : i64 to index
    %18 = linalg.index 2 : index
    %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
    linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>

MaheshRavishankar commented 4 months ago

Confirmed that the fusion is not expected. @MaheshRavishankar will fix it.

For the gather codegen issue, @pashu123 could you create an input case for the generic op and see what's happening? I'm expecting that some dimensions would be collapsed, and the next issue could be tile size selection. #17227 could help, but there could be other issues remaining on the table.

(the gather linalg.generic op quoted in full above)

Well, I said the fusion was not expected because I wasn't looking at it properly. It is expected, and I think it is probably what you want at the dispatch level. If we don't fuse this, we will materialize a tensor of size 128256x4096x4 bytes, which is completely unnecessary.

The real issue, though, is that the op shouldn't be lowered this way. A better representation of this would be:

%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<4x?xi64>) outs(%3 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %9 = arith.index_cast %in : i64 to index
      %10 = linalg.index 2 : index
      %extracted = tensor.extract %5[%9, %10] : tensor<128256x4096xf16>
      %extracted_f32 = arith.extf %extracted : f16 to f32
      linalg.yield %extracted_f32 : f32
    } -> tensor<4x?x4096xf32>

That should fix one of the issues Hanhan mentioned. If we can fix the front end to do this, that would be best. If not, then we should just write an ad-hoc pattern that does this kind of fusion, as sketched below. There is really nothing structured to generalize here; this is just a specific pattern that is a workaround (WAR) for a front-end lowering issue.
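
If the ad-hoc route is taken, a rough sketch of what such a pattern could look like against upstream MLIR APIs is below; this is not IREE's implementation, and the pattern name and the exact legality checks are assumptions for illustration:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical pattern: fold an element-wise f16->f32 extending linalg.generic
// into a consumer that only reads it through tensor.extract, so the widened
// tensor (e.g. 128256x4096xf32) is never materialized.
struct FoldExtfProducerIntoExtract : OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // The extracted tensor must be produced by a single-input linalg.generic
    // whose indexing maps are all identities (a pure element-wise op).
    auto producer = extractOp.getTensor().getDefiningOp<linalg::GenericOp>();
    if (!producer || producer->getNumResults() != 1 ||
        producer.getNumDpsInputs() != 1)
      return failure();
    if (!llvm::all_of(producer.getIndexingMapsArray(),
                      [](AffineMap map) { return map.isIdentity(); }))
      return failure();

    // The producer body must be just `arith.extf` of the input element.
    Block &body = producer.getRegion().front();
    auto yieldOp = cast<linalg::YieldOp>(body.getTerminator());
    auto extOp = yieldOp.getValues()[0].getDefiningOp<arith::ExtFOp>();
    if (!extOp || extOp.getIn() != body.getArgument(0))
      return failure();

    // Extract directly from the narrow (f16) operand and widen the scalar.
    Value narrowSource = producer.getDpsInputOperand(0)->get();
    Value scalar = rewriter.create<tensor::ExtractOp>(
        extractOp.getLoc(), narrowSource, extractOp.getIndices());
    rewriter.replaceOpWithNewOp<arith::ExtFOp>(extractOp, extractOp.getType(),
                                               scalar);
    return success();
  }
};

The match is deliberately narrow (a single-input, identity-map, extf-only producer feeding a tensor.extract), mirroring the embedding lookup above rather than attempting a general gather fusion.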

pashu123 commented 4 months ago

(Quoting the embedding-cast optimization suggested above: move torch.prims.convert_element_type after torch.aten.embedding.)

@MaheshRavishankar, does this sound reasonable to add to torch-mlir?

pashu123 commented 4 months ago

(Quoting my embedding-cast suggestion and the question to @MaheshRavishankar above.)

Added here: https://github.com/llvm/torch-mlir/pull/3277

stellaraccident commented 4 months ago

Not sure why this keeps closing

stellaraccident commented 4 months ago

(Quoting the embedding-cast optimization suggested above.)

FYI, if you can make the torch embedding lookup good, that is best. But also I carved this out for a potential special op: it would be trivial to write a custom op at the frontend that expanded to whatever linalg you want.

benvanik commented 4 months ago

Not sure why this keeps closing

@pashu123 put a "fixes" command in a commit message and now anyone who has write access to the repo will close it when they merge in that commit to their forks of whatever :P https://github.com/aartbik/torch-mlir/commit/8c48135a426b84fa412b031fc92e12826ff60b31

stellaraccident commented 4 months ago

With that said, moving the cast across the embedding lookup is a common optimization.

I'm a bit worried that the default path on this generates basically unusable code, though.

MaheshRavishankar commented 4 months ago

With that said, moving the cast across the embedding lookup is a common optimization.

I'm a bit worried that the default path on this generates basically unusable code, though.

That's fair, but we just don't represent gathers well. And if we clone the quantization into all the dispatches that use it (as we do now, under the current understanding of the best way to handle dequantization), none of the transformations can actually fuse and generate this code. The producer-consumer dependency only materializes from within the body of the consumer. Nothing accounts for that, and it just falls off the cliff.

qedawkins commented 4 months ago

Not sure why this keeps closing

@pashu123 put a "fixes" command in a commit message and now anyone who has write access to the repo will close it when they merge in that commit to their forks of whatever :P aartbik/torch-mlir@8c48135

Why is GitHub unable to prevent actions on forks from spamming main repos... Seems like a big anti-feature.

pashu123 commented 4 months ago

@aviator19941 Do you have instructions on how to run llama3 for the IREE backend?

MaheshRavishankar commented 4 months ago

With that said, moving the cast across the embedding lookup is a common optimization.

I'm a bit worried that the default path on this generates basically unusable code, though.

The more I think about this, the more it might be worth just doing the fusion of

 %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%11 : tensor<128256x4096xf16>) outs(%14 : tensor<128256x4096xf32>) {
  ^bb0(%in: f16, %out: f32):
    %17 = arith.extf %in : f16 to f32
    linalg.yield %17 : f32
  } -> tensor<128256x4096xf32>
  %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%12 : tensor<4x?xi64>) outs(%13 : tensor<4x?x4096xf32>) {
  ^bb0(%in: i64, %out: f32):
    %17 = arith.index_cast %in : i64 to index
    %18 = linalg.index 2 : index
    %extracted = tensor.extract %15[%17, %18] : tensor<128256x4096xf32>
    linalg.yield %extracted : f32
  } -> tensor<4x?x4096xf32>

to

%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%2 : tensor<4x?xi64>) outs(%3 : tensor<4x?x4096xf32>) {
    ^bb0(%in: i64, %out: f32):
      %9 = arith.index_cast %in : i64 to index
      %10 = linalg.index 2 : index
      %extracted = tensor.extract %5[%9, %10] : tensor<128256x4096xf16>
      %extracted_f32 = arith.extf %extracted : f16 to f32
      linalg.yield %extracted_f32 : f32
    } -> tensor<4x?x4096xf32>

as a one-off canonicalization for now so we don't fall off a cliff. It might be hard to make it future-proof, but more examples will help. @IanWood1, just FYI, this is something for us to discuss (and for you to pick up as a simple task). Please make sure we chat about this next time we sync.

benvanik commented 4 months ago

Agreed on handling this even if it is not generalized, as it's pretty catastrophic to clone embeddings.

I think the more durable fix may be proper propagation: we should sink any exts down / hoist truncs up across memcpy-like ops (such as this gather or a scatter). With the current logic we may be in a better situation, but we still want to ensure we don't materialize ext/trunc dispatches unless absolutely required.

zjgarvey commented 4 months ago

Noting that this issue also occurs with some other models. In the SHARK-TestSuite, onnx/models/RAFT_vaiq_int8 encounters a similar issue. To reproduce, set up the test suite and run

python run.py --cachedir=/path/to/.cache/ -t onnx/models/RAFT_vaiq_int8/ -m onnx -c /path/to/torch-mlir/build/ -i /path/to/iree-build/ --torchtolinalg

with an up-to-date torch-mlir and iree build.