nod-ai / sharktank

SHARK Inference Modeling and Serving
Apache License 2.0

Tracking running llama models through IREE #22

Open ScottTodd opened 3 months ago

ScottTodd commented 3 months ago

Goal

Run a llama model from https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/models/llama/llama.py through IREE

Starting with open_llama_3b_v2_f16_gguf since we have that in docs. Could try another model or data type, but we should eventually get all sorts of variants working.

Approach

https://github.com/nod-ai/sharktank/tree/main/sharktank/sharktank/examples has a few files already:

| file | description |
| --- | --- |
| paged_llm_v1.py | Run LLM (from GGUF or hyperparameter config + parameter weights) in PyTorch |
| export_paged_llm_v1.py | Export LLM to a .mlir file for IREE |

Next steps from there could be:

  1. Compile the .mlir file using iree-compile and run it using iree-run-module
  2. Add an IREE version of paged_llm_v1.py (see the sketch after this list) that could either
    • Export (e.g. from GGUF) -> compile -> run, all in-process
    • Compile from .mlir -> run
    • Take an already compiled .vmfb and run it
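
A minimal sketch of that last option (load an already compiled .vmfb and call its entry points from Python), assuming the iree-runtime pip package; the path and the argument names are illustrative:

import iree.runtime as ireert

# Load an already compiled module and bind it to a driver (here: CPU / local-task).
module = ireert.load_vm_flatbuffer_file(
    "open_llama_3b_v2_f16_cpu.vmfb", driver="local-task")

# Exported entry points become callable attributes that take/return numpy arrays, e.g.:
#   logits = module.decode_bs4(tokens, seq_lens, start_positions, seq_block_ids, cache_state)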

Worklog

Export -> try compiling the entire program ("prefill" and "decode")

Next: continue triaging compilation errors for prefill.

Export and run just "decode"

Next: try upgrading GGUF version 2 to 3? Load from safetensors? Convert to IRPA?

pashu123 commented 3 months ago

For the spirv-vulkan backend, here's the minimal repro:

func.func @torch_add(%arg0: !torch.vtensor<[1,1,?,?],i1>, %arg1: !torch.vtensor<[4,1,1,?],i1>) -> !torch.vtensor<[4,1,?,?],i1> {
  %int1 = torch.constant.int 1
  %2 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[1,1,?,?],i1>, !torch.vtensor<[4,1,1,?],i1>, !torch.int -> !torch.vtensor<[4,1,?,?],i1>
  return %2 : !torch.vtensor<[4,1,?,?],i1>
}

error: spirv.IAdd op operand #0 must be 8/16/32/64-bit integer but got i1.
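
For context, a torch-level reconstruction of the pattern (my guess at the source, not the actual model code): adding two broadcast boolean attention masks stays boolean, which becomes the i1 add.Tensor that the SPIR-V backend rejects.

import torch

# Two boolean masks with the broadcastable shapes from the repro above.
causal_mask = torch.zeros(1, 1, 7, 7, dtype=torch.bool)   # [1,1,?,?]
batch_mask = torch.zeros(4, 1, 1, 7, dtype=torch.bool)    # [4,1,1,?]

# Adding bool tensors is allowed in torch (acts like logical OR) and keeps dtype bool,
# which is the i1 torch.aten.add.Tensor seen in the repro.
combined = causal_mask + batch_mask                        # shape [4,1,7,7], dtype bool
print(combined.shape, combined.dtype)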

pashu123 commented 3 months ago

Pulling in some of the comments from the chat. For the CPU backend there are two options:

  1. Use the --iree-opt-demote-i64-to-i32 flag; these models deal with large numbers, though, so truncating might not be a good strategy.
  2. Use the --iree-opt-strip-assertions flag; there are assertions hanging around, and stripping them lets the model compile.

For spirv-vulkan backend I have posted the minimal repro above.
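
When compiling in-process from Python instead of via the iree-compile CLI, the same flags can be passed through the compiler API. A sketch, assuming the iree-compiler pip package and an illustrative input path:

import iree.compiler as ireec

# Compile the exported MLIR for CPU, stripping the assertions that otherwise block
# compilation (equivalent to passing --iree-opt-strip-assertions to iree-compile).
vmfb_bytes = ireec.compile_file(
    "open_llama_3b_v2_f16.mlir",
    target_backends=["llvm-cpu"],
    extra_args=["--iree-opt-strip-assertions"],
)

with open("open_llama_3b_v2_f16_cpu.vmfb", "wb") as f:
    f.write(vmfb_bytes)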

stellaraccident commented 3 months ago

I think that assert can be safely dropped at the torch level in the same way as the broadcast asserts: when in strict mode from torch, the invariant being checked for dynamic legality must be true (torch enforces it).

ScottTodd commented 3 months ago

Thanks, I'm also able to compile for llvm-cpu with --iree-opt-strip-assertions. edit: specifically with https://github.com/llvm/torch-mlir/pull/3277 too

ScottTodd commented 3 months ago

Compilation correctness

GGUF version 2 vs version 3

Running just decode, with zeroed arguments:

Vulkan:

iree-run-module --module=/tmp/open_llama_3b_v2_f16_vulkan_decode_only.vmfb --device=vulkan --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=/tmp/huggingface/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf

EXEC @decode_bs4
D:\dev\projects\iree\runtime\src\iree\hal\command_buffer_validation.c:363: INVALID_ARGUMENT; source and target ranges overlap within the same buffer; stack:
  0x00007ff6e1f7238f iree-run-module <iree_hal_command_buffer_copy_buffer_validation+0x23f> (D:\dev\projects\iree\runtime\src\iree\hal\command_buffer_validation.c:361)
  0x00007ff6e1f67e18 iree-run-module <iree_hal_command_buffer_copy_buffer+0xa8> (D:\dev\projects\iree\runtime\src\iree\hal\command_buffer.c:458)
  0x00007ff6e1f00a72 iree-run-module <iree_hal_module_command_buffer_copy_buffer+0xc2> (D:\dev\projects\iree\runtime\src\iree\modules\hal\module.c:798)
  0x00007ff6e1f16642 iree-run-module <iree_vm_shim_rrIrII_v+0x82> (D:\dev\projects\iree\runtime\src\iree\vm\shims.c:65)
  0x00007ff6e1f19754 iree-run-module <iree_vm_native_module_issue_call+0x84> (D:\dev\projects\iree\runtime\src\iree\vm\native_module.c:342)

CPU (local-task): assert hit, --trace-execution output: https://gist.github.com/ScottTodd/8c215d943f6f27fa480a8ba5ed328cb3

iree-run-module --module=/tmp/open_llama_3b_v2_f16_cpu.vmfb --device=local-sync --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=/tmp/huggingface/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf --function=decode_bs4 --trace-execution

...
[module.decode_bs4$async+000410C2]    %r0 = vm.call @hal.command_buffer.create(%r266(!hal.device/0x0000015AD90B22D0), %i206(1), %i206(1), %i83(0))
[module.decode_bs4$async+000410D6]    vm.call @hal.command_buffer.copy_buffer(%r0(!hal.command_buffer/0x0000015CAECF9450), %r4(!hal.buffer/0x0000015C7B094080), %i84(0), %r4(!hal.buffer/0x0000015C7B094080), %i84(0), %i100(10649600))

--- assert hit ---
ucrtbase.dll!00007ff8674d286e() (Unknown Source:0)
iree-run-module.exe!iree_abort() Line 26 (d:\dev\projects\iree\runtime\src\iree\base\assert.h:26)
iree-run-module.exe!iree_vm_buffer_deinitialize(iree_vm_buffer_t * buffer) Line 79 (d:\dev\projects\iree\runtime\src\iree\vm\buffer.c:79)
iree-run-module.exe!iree_vm_bytecode_module_destroy(void * self) Line 152 (d:\dev\projects\iree\runtime\src\iree\vm\bytecode\module.c:152)
iree-run-module.exe!iree_vm_context_release_modules(iree_vm_context_t * context, unsigned __int64 start, unsigned __int64 end) Line 288 (d:\dev\projects\iree\runtime\src\iree\vm\context.c:288)
iree-run-module.exe!iree_vm_context_destroy(iree_vm_context_t * context) Line 362 (d:\dev\projects\iree\runtime\src\iree\vm\context.c:362)
iree-run-module.exe!iree_tooling_run_module_with_data(iree_vm_instance_t * instance, iree_string_view_t default_device_uri, iree_const_byte_span_t module_contents, iree_allocator_t host_allocator, int * out_exit_code) Line 422 (d:\dev\projects\iree\runtime\src\iree\tooling\run_module.c:422)

Next: figure out the runtime errors. Miscompile? Going over some runtime limits? local-sync and local-task have different errors. Look at the VM IR and see if anything stands out.

ScottTodd commented 3 months ago

I created a mock version of open_llama_3b_v2_f16.mlir here: https://gist.github.com/ScottTodd/ee0cd9d6ab80e4814edad353235cf664. It just returns 1 for all values (no math/kernels/etc.).

Compile with:

iree-compile \
  mock_open_llama_3b_v2_f16.mlir \
  --iree-hal-target-backends=llvm-cpu \
  -o mock_open_llama_3b_v2_f16_cpu.vmfb

Run prefill with:

iree-run-module \
  --module=mock_open_llama_3b_v2_f16_cpu.vmfb \
  --device=local-sync \
  --function=prefill_bs4 \
  --input=4x1xi64 \
  --input=4xi64 \
  --input=4x1xi64 \
  --input=1x2662400xf32 \
  --parameters=model=open-llama-3b-v2-f16.gguf

Run decode with:

iree-run-module \
  --module=mock_open_llama_3b_v2_f16_cpu.vmfb \
  --device=local-sync \
  --function=decode_bs4 \
  --input=4x1xi64 \
  --input=4xi64 \
  --input=4xi64 \
  --input=4x1xi64 \
  --input=1x2662400xf32 \
  --parameters=model=open-llama-3b-v2-f16.gguf

I'm planning on loading that into Python and standing up an IREE version of https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/examples/paged_llm_v1.py . Once the real model compiles, I'll substitute it.
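
As a starting point, a sketch of what that Python runner could look like against the mock module, mirroring the decode invocation above with zeroed inputs (assumes the iree-runtime pip package; the Python argument names are just descriptive guesses at the exported signature):

import numpy as np
import iree.runtime as ireert

# The mock module returns constants, so no GGUF parameters are needed here.
module = ireert.load_vm_flatbuffer_file(
    "mock_open_llama_3b_v2_f16_cpu.vmfb", driver="local-sync")

tokens = np.zeros((4, 1), dtype=np.int64)
seq_lens = np.zeros((4,), dtype=np.int64)
start_positions = np.zeros((4,), dtype=np.int64)
seq_block_ids = np.zeros((4, 1), dtype=np.int64)
cache_state = np.zeros((1, 2662400), dtype=np.float32)

logits = module.decode_bs4(
    tokens, seq_lens, start_positions, seq_block_ids, cache_state)
print(np.asarray(logits).shape)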

stellaraccident commented 3 months ago

> Thanks, I'm also able to compile for llvm-cpu with --iree-opt-strip-assertions. edit: specifically with llvm/torch-mlir#3277 too

This upstream patch removes these assertions and implements a more direct lowering (no more switchy stuff): https://github.com/llvm/torch-mlir/pull/3319

ScottTodd commented 3 months ago

Latest attempt:


Compile for Vulkan: D:\dev\projects\iree-build\tools\iree-compile D:\tmp\open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=turing-unknown-unknown -o /tmp/open_llama_3b_v2_f16_vulkan_decode_only_17339b.vmfb --iree-hal-executable-debug-level=3

Run on Vulkan: D:\dev\projects\iree-build\tools\iree-run-module --module=D:\tmp\open_llama_3b_v2_f16_vulkan_decode_only_17339b.vmfb --device=vulkan --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=D:\dev\projects\iree-data\huggingface\open_llama_3b_v2_gguf\open-llama-3b-v2-f16.gguf

Vulkan output:

EXEC @decode_bs4
result[0]: hal.buffer_view
4x1x32000xf32=[[NAN NAN NAN NAN NAN NAN NAN ...

Compile for CPU: D:\dev\projects\iree-build\tools\iree-compile D:\tmp\open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=llvm-cpu -o /tmp/open_llama_3b_v2_f16_llvmcpu_decode_only_17339b.vmfb --iree-hal-executable-debug-level=3

Run on CPU: D:\dev\projects\iree-build\tools\iree-run-module --module=D:\tmp\open_llama_3b_v2_f16_llvmcpu_decode_only_17339b.vmfb --device=local-task --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=D:\dev\projects\iree-data\huggingface\open_llama_3b_v2_gguf\open-llama-3b-v2-f16.gguf

CPU crashes inside a dispatch (iree_elf_call_i_ppp).


Will trace execution and look at individual dispatches to go deeper.

ScottTodd commented 2 months ago

Still debugging a runtime crash in decode with @rsuderman.

We suspect that the in-place scatter operations are writing out of bounds. The exported programs had a sequence of scatters back to back, so Rob has a branch (https://github.com/rsuderman/sharktank/tree/rework_update) that makes the key/value cache updates use a single scatter (if I'm understanding correctly). The model fails to compile after those changes.

I have a reduced test case of just a single index_put_ (an in-place operation that lowers to scatter and uses torch.overwrite.tensor.contents) here: https://gist.github.com/ScottTodd/df0d426a351a6737e16f507b187a210b . It looks like an issue in the torch-mlir lowering, since it reproduces with torch-mlir-opt --pass-pipeline="builtin.module(func.func(torch-decompose-complex-ops,convert-torch-to-tmtensor))"
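
For reference, a simplified torch illustration of the kind of in-place cache update involved (not the exact sharktank code; dims are modeled loosely on the model's ?x26x2x16x32x100 cache layout):

import torch

# Paged KV cache and a batch of page updates.
cache = torch.zeros(8, 26, 2, 16, 32, 100, dtype=torch.float16)
page_ids = torch.tensor([0, 1, 2, 3])                             # pages to overwrite
new_kv = torch.ones(4, 26, 2, 16, 32, 100, dtype=torch.float16)   # data for those pages

# In-place write along the page dimension; when exported, this is the pattern that
# lowers to a scatter plus torch.overwrite.tensor.contents.
cache.index_put_((page_ids,), new_kv)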

ScottTodd commented 2 months ago

A different reduced test (IR here, starting from the full llama model) was hitting an assert while compiling: https://gist.github.com/ScottTodd/366fe4b993c3d8e9776c40eddc4a6493

Some debugging around the callstack also pointed at scatter ops:

--- areOpsFusable ---
  producer:
%48 = iree_linalg_ext.scatter dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_27, %47 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%expanded_21 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
  iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
  consumer:
%51 = iree_linalg_ext.scatter {__root_op__ = 17 : i64} dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_35, %50 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%48 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
  iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>

I'm not sure if that is worth debugging further; it may have been a buggy test case reduction. Going to follow up on the minimal index_put_ compilation error above next.

ScottTodd commented 2 months ago

Filed https://github.com/llvm/torch-mlir/issues/3433 for the index_put_ lowering that fails. Not sure if that's unique to our reduced test cases or if it appears in the full model too. Building out more test coverage and confidence in the operations e2e will help anyway.

ScottTodd commented 2 months ago

> A different reduced test (IR here, starting from the full llama model) was hitting an assert while compiling: https://gist.github.com/ScottTodd/366fe4b993c3d8e9776c40eddc4a6493

This occurs in the full model too. Can work around it by disabling all dispatch region fusions (add a return false around here). Should file a reproducer upstream, since the compiler must not crash (assert) on valid input. If the input is invalid, then we'd need to update the frontend (torch-mlir / iree-turbine / sharktank) instead.

ScottTodd commented 2 months ago

Compiling with --iree-input-demote-i64-to-i32 works around the runtime crash with the decode() function. We're trying to update the model definition (in the Python source) to use i32 while also digging into why the runtime crashes with i64.

ScottTodd commented 2 months ago

Tried to change dtypes in the model from i64 to i32 (https://github.com/nod-ai/sharktank/compare/main...ScottTodd:llama-i32?expand=1) and ran into errors compiling after export, like this:

~/iree-build/tools/iree-compile ~/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir -o ~/scratch/open_llama_3b_v2_f16_i32more3_1block_asan.vmfb --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-link-embedded=false --iree-llvmcpu-sanitize=address
/home/scotttodd/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir:11259:11: error: 'arith.cmpi' op requires all operands to have the same type
    %43 = torch.aten.index.Tensor %0, %42 : !torch.vtensor<[2048,50],complex<f32>>, !torch.list<optional<vtensor>> -> !torch.vtensor<[4,1,50],complex<f32>>
          ^
/home/scotttodd/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir:11259:11: note: see current operation: %3741 = "arith.cmpi"(%arg275, %3740) <{predicate = 2 : i64}> : (i32, i64) -> i1

It sounds like https://github.com/iree-org/iree/pull/17696 fixes decode crashes while still using i32 types.
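
For illustration only (not the actual branch diff), the kind of dtype change attempted above is building index/position tensors as int32 at construction instead of torch's default int64, and it has to be applied consistently or mixed-type comparisons like the arith.cmpi above show up:

import torch

seq_len = 50
positions_i64 = torch.arange(seq_len)                     # torch defaults to int64
positions_i32 = torch.arange(seq_len, dtype=torch.int32)  # exports as i32

# Comparing or combining an i32 tensor with an i64 tensor elsewhere in the graph is
# what produces the mixed-type arith.cmpi error shown above.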

ScottTodd commented 2 months ago

Confirmed that these patches help.

Altogether, decode appears to work (outputs look sensible and aligned with prefill). I'll continue to validate.
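
One possible shape for that continued validation, comparing IREE decode logits against the eager PyTorch reference for the same inputs (a sketch; names and tolerances are placeholders, not an existing sharktank helper):

import numpy as np

def logits_match(reference_logits, iree_logits, rtol=1e-2, atol=1e-2):
    """Loosely compare 4x1x32000 logits from PyTorch vs IREE (f16-friendly tolerances)."""
    reference = np.asarray(reference_logits, dtype=np.float32)
    candidate = np.asarray(iree_logits, dtype=np.float32)
    if np.allclose(reference, candidate, rtol=rtol, atol=atol):
        return True
    print("Max abs diff:", np.max(np.abs(reference - candidate)))
    return False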

ScottTodd commented 1 month ago

Ideas for next steps / follow-up tasks:

ScottTodd commented 1 month ago

Still seeing a crash in decode on Windows with these args:

iree-run-module \
  --module=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb \
  --function=decode_bs4 \
  --device=local-task \
  --input=4x1xi64=0 \
  --input=4xi64=1 \
  --input=4xi64=1 \
  --input=4x1xi64=0,1,2,3 \
  --input=1x2662400xf16 \
  --parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf

I'll wrap all my repro steps (documented here: https://github.com/nod-ai/sharktank/pull/69) into a script and run that script across my machines. Hopefully it's just a case of needing the cache (that --input=1x2662400xf16 arg) to be populated.
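
A sketch of what that wrapper script could look like (Python + subprocess so the same file runs on Windows and Linux; args copied from the command above and would be adjusted per machine):

import subprocess

REPRO_ARGS = [
    "iree-run-module",
    "--module=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb",
    "--function=decode_bs4",
    "--device=local-task",
    "--input=4x1xi64=0",
    "--input=4xi64=1",
    "--input=4xi64=1",
    "--input=4x1xi64=0,1,2,3",
    "--input=1x2662400xf16",
    "--parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf",
]

result = subprocess.run(REPRO_ARGS, capture_output=True, text=True)
print("exit code:", result.returncode)
print(result.stdout)
print(result.stderr)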