ScottTodd opened 3 months ago
For the vulkan-spirv backend, here's the minimal repro:
func.func @torch_add(%arg0: !torch.vtensor<[1,1,?,?],i1>, %arg1: !torch.vtensor<[4,1,1,?],i1>) -> !torch.vtensor<[4, 1, ?, ?],i1> {
%int1 = torch.constant.int 1
%2 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[1,1,?,?],i1>, !torch.vtensor<[4,1,1,?],i1>, !torch.int -> !torch.vtensor<[4,1,?,?],i1>
return %2 : !torch.vtensor<[4,1,?,?],i1>
}
error: 'spirv.IAdd' op operand #0 must be 8/16/32/64-bit integer, but got 'i1'
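For reference, here's a rough Python-level sketch (shapes are made up, not from the model) of where a pattern like this can come from: adding two broadcastable boolean mask tensors. bool + bool stays bool in torch, which imports as torch.aten.add.Tensor on i1 operands and then fails to lower to spirv.IAdd.
import torch

# Hypothetical origin of the i1 add above: combining two masks with torch.add.
a = torch.zeros(1, 1, 7, 7, dtype=torch.bool)  # e.g. a causal mask
b = torch.zeros(4, 1, 1, 7, dtype=torch.bool)  # e.g. a padding mask
c = torch.add(a, b)  # broadcasts to [4, 1, 7, 7]; result dtype is still bool (i1)
print(c.shape, c.dtype)  # torch.Size([4, 1, 7, 7]) torch.bool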
Pulling in some of the comments from the chat. For the CPU backend there are two options:
- the --iree-opt-demote-i64-to-i32 flag; these models deal with large numbers, so truncating might not be a good strategy.
- the --iree-opt-strip-assertions flag; assertions are hanging around, and this strips them so the model compiles.
For the vulkan-spirv backend I have posted the minimal repro above.
I think that assert can be safely dropped at the torch level in the same way as the broadcast asserts: when in strict mode from torch, the invariant being checked for dynamic legality must be true (torch enforces it).
Thanks, I'm also able to compile for llvm-cpu with --iree-opt-strip-assertions.
edit: specifically with https://github.com/llvm/torch-mlir/pull/3277 too
- llvm-cpu: compiles with --iree-opt-strip-assertions (sounds like we should fix the frontend to omit asserts from aten.view)
- vulkan-spirv: (some i1 handling)
Confirmed GGUF version 2 on https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf/tree/main?show_file_info=open-llama-3b-v2-f16.gguf
Found version upgrade instructions at https://github.com/ggerganov/llama.cpp?tab=readme-ov-file#prepare-and-quantize; requires building llama.cpp/examples/quantize.cpp from source?
Found version 3 in https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF/tree/main?show_file_info=Meta-Llama-3-8B.Q8_0.gguf, going to try that:
huggingface-cli download --local-dir /tmp/huggingface/llama3_8B QuantFactory/Meta-Llama-3-8B-GGUF Meta-Llama-3-8B.Q8_0.gguf
python -m sharktank.examples.export_paged_llm_v1 --hf-dataset=llama3_8B_q8_0 --output=/tmp/llama3_8B_q8_0.mlir
Exporting decode_bs4
Traceback (most recent call last):
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_dynamo\utils.py", line 1766, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\utils\_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 896, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 1241, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 974, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_subclasses\fake_tensor.py", line 1393, in _dispatch_impl
    return decomposition_table[func](*args, **kwargs)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_refs\__init__.py", line 4547, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_refs\__init__.py", line 3629, in _reshape_view_helper
    shape = utils.infer_size(shape, a.numel())
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_prims_common\__init__.py", line 891, in infer_size
    if d == -1:
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\__init__.py", line 374, in __bool__
    return self.node.bool_()
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\fx\experimental\sym_node.py", line 432, in bool_
    return self.guard_bool("", 0)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\fx\experimental\sym_node.py", line 374, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\fx\experimental\recording.py", line 231, in wrapper
    return fn(*args, **kwargs)
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\fx\experimental\symbolic_shapes.py", line 4138, in evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, -1) (unhinted: Eq(u0, -1)). (Size-like symbols: none)
Potential framework code culprit (scroll up for full backtrace):
  File "D:\dev\projects\sharktank\.venv\Lib\site-packages\torch\_prims_common\__init__.py", line 891, in infer_size
    if d == -1:
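As a minimal sketch (my own toy reproduction, not the sharktank code): this guard fires when a view/reshape target shape mixes -1 with an unbacked symbolic size, so infer_size has to ask whether a data-dependent value equals -1.
import torch

class M(torch.nn.Module):
    def forward(self, x, n):
        k = n.item()                # unbacked SymInt (u0)
        return x.reshape(4, k, -1)  # infer_size must check k == -1: data-dependent

# torch.export.export(M(), (torch.ones(4, 8), torch.tensor(2)))  # raises
# GuardOnDataDependentSymNode. One workaround is to bound the symbol first,
# e.g. torch._check(k >= 0), so the k == -1 comparison resolves statically.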
https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF is also version 2
Looked into versions and found that GGUF version 3 just "added" big endian support, so we should be able to support both version 2 and version 3. Trying that with open_llama_3b_v2_gguf again.
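To check a file's version without running IREE, the GGUF header can be read directly (a small sketch assuming the standard little-endian GGUF layout: a 4-byte magic "GGUF" followed by a uint32 version):
import struct

def gguf_version(path):
    # GGUF header: 4-byte magic b"GGUF", then uint32 version (little-endian).
    with open(path, "rb") as f:
        magic, version = struct.unpack("<4sI", f.read(8))
    if magic != b"GGUF":
        raise ValueError(f"not a GGUF file: {magic!r}")
    return version

print(gguf_version("/tmp/huggingface/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf"))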
Vulkan:
iree-run-module --module=/tmp/open_llama_3b_v2_f16_vulkan_decode_only.vmfb --device=vulkan --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=/tmp/huggingface/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf
EXEC @decode_bs4
D:\dev\projects\iree\runtime\src\iree\hal\command_buffer_validation.c:363: INVALID_ARGUMENT; source and target ranges overlap within the same buffer; stack:
0x00007ff6e1f7238f iree-run-module <iree_hal_command_buffer_copy_buffer_validation+0x23f> (D:\dev\projects\iree\runtime\src\iree\hal\command_buffer_validation.c:361)
0x00007ff6e1f67e18 iree-run-module <iree_hal_command_buffer_copy_buffer+0xa8> (D:\dev\projects\iree\runtime\src\iree\hal\command_buffer.c:458)
0x00007ff6e1f00a72 iree-run-module <iree_hal_module_command_buffer_copy_buffer+0xc2> (D:\dev\projects\iree\runtime\src\iree\modules\hal\module.c:798)
0x00007ff6e1f16642 iree-run-module <iree_vm_shim_rrIrII_v+0x82> (D:\dev\projects\iree\runtime\src\iree\vm\shims.c:65)
0x00007ff6e1f19754 iree-run-module <iree_vm_native_module_issue_call+0x84> (D:\dev\projects\iree\runtime\src\iree\vm\native_module.c:342)
CPU (local-task): assert hit. --trace-execution output: https://gist.github.com/ScottTodd/8c215d943f6f27fa480a8ba5ed328cb3
iree-run-module --module=/tmp/open_llama_3b_v2_f16_cpu.vmfb --device=local-sync --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=/tmp/huggingface/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf --function=decode_bs4 --trace-execution
...
[module.decode_bs4$async+000410C2] %r0 = vm.call @hal.command_buffer.create(%r266(!hal.device/0x0000015AD90B22D0), %i206(1), %i206(1), %i83(0))
[module.decode_bs4$async+000410D6] vm.call @hal.command_buffer.copy_buffer(%r0(!hal.command_buffer/0x0000015CAECF9450), %r4(!hal.buffer/0x0000015C7B094080), %i84(0), %r4(!hal.buffer/0x0000015C7B094080), %i84(0), %i100(10649600))
--- assert hit ---
ucrtbase.dll!00007ff8674d286e() (Unknown Source:0)
iree-run-module.exe!iree_abort() Line 26 (d:\dev\projects\iree\runtime\src\iree\base\assert.h:26)
iree-run-module.exe!iree_vm_buffer_deinitialize(iree_vm_buffer_t * buffer) Line 79 (d:\dev\projects\iree\runtime\src\iree\vm\buffer.c:79)
iree-run-module.exe!iree_vm_bytecode_module_destroy(void * self) Line 152 (d:\dev\projects\iree\runtime\src\iree\vm\bytecode\module.c:152)
iree-run-module.exe!iree_vm_context_release_modules(iree_vm_context_t * context, unsigned __int64 start, unsigned __int64 end) Line 288 (d:\dev\projects\iree\runtime\src\iree\vm\context.c:288)
iree-run-module.exe!iree_vm_context_destroy(iree_vm_context_t * context) Line 362 (d:\dev\projects\iree\runtime\src\iree\vm\context.c:362)
iree-run-module.exe!iree_tooling_run_module_with_data(iree_vm_instance_t * instance, iree_string_view_t default_device_uri, iree_const_byte_span_t module_contents, iree_allocator_t host_allocator, int * out_exit_code) Line 422 (d:\dev\projects\iree\runtime\src\iree\tooling\run_module.c:422)
Next: figure out the runtime errors. Miscompile? Going over some runtime limits? local-sync and local-task have different errors. Look at the VM IR and see if anything stands out. (Note that the traced copy_buffer call above copies buffer %r4 onto itself at offset 0, which is exactly the kind of overlap Vulkan's validation rejected.)
I created a mock version of open_llama_3b_v2_f16.mlir here: https://gist.github.com/ScottTodd/ee0cd9d6ab80e4814edad353235cf664. That just returns 1 for all values (no math/kernels/etc.).
Compile with:
iree-compile \
mock_open_llama_3b_v2_f16.mlir \
--iree-hal-target-backends=llvm-cpu \
-o mock_open_llama_3b_v2_f16_cpu.vmfb
Run prefill with:
iree-run-module \
--module=mock_open_llama_3b_v2_f16_cpu.vmfb \
--device=local-sync \
--function=prefill_bs4 \
--input=4x1xi64 \
--input=4xi64 \
--input=4x1xi64 \
--input=1x2662400xf32 \
--parameters=model=open-llama-3b-v2-f16.gguf
Run decode with:
iree-run-module \
--module=mock_open_llama_3b_v2_f16_cpu.vmfb \
--device=local-sync \
--function=decode_bs4 \
--input=4x1xi64 \
--input=4xi64 \
--input=4xi64 \
--input=4x1xi64 \
--input=1x2662400xf32 \
--parameters=model=open-llama-3b-v2-f16.gguf
I'm planning on loading that into Python and standing up an IREE version of https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/examples/paged_llm_v1.py . Once the real model compiles, I'll substitute it.
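A rough sketch of what that Python side could look like with the iree.runtime bindings (the API usage here is my assumption; parameter wiring is omitted since the mock just returns ones, and input names are guesses based on the commands above):
import numpy as np
import iree.runtime as ireert

config = ireert.Config("local-sync")
with open("mock_open_llama_3b_v2_f16_cpu.vmfb", "rb") as f:
    vm_module = ireert.VmModule.copy_buffer(config.vm_instance, f.read())
ctx = ireert.SystemContext(vm_modules=[vm_module], config=config)

# Input shapes mirror the decode command above.
result = ctx.modules.module.decode_bs4(
    np.zeros((4, 1), dtype=np.int64),          # tokens
    np.ones((4,), dtype=np.int64),             # seq lens
    np.ones((4,), dtype=np.int64),             # start positions
    np.zeros((4, 1), dtype=np.int64),          # seq block ids
    np.zeros((1, 2662400), dtype=np.float32),  # paged KV cache state
)
print(result.to_host())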
> Thanks, I'm also able to compile for llvm-cpu with --iree-opt-strip-assertions. edit: specifically with llvm/torch-mlir#3277 too
This upstream patch removes these assertions and implements a more direct lowering (no more switchy stuff): https://github.com/llvm/torch-mlir/pull/3319
Latest attempt:
Compile for Vulkan: D:\dev\projects\iree-build\tools\iree-compile D:\tmp\open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=turing-unknown-unknown -o /tmp/open_llama_3b_v2_f16_vulkan_decode_only_17339b.vmfb --iree-hal-executable-debug-level=3
Run on Vulkan: D:\dev\projects\iree-build\tools\iree-run-module --module=D:\tmp\open_llama_3b_v2_f16_vulkan_decode_only_17339b.vmfb --device=vulkan --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=D:\dev\projects\iree-data\huggingface\open_llama_3b_v2_gguf\open-llama-3b-v2-f16.gguf
Vulkan output:
EXEC @decode_bs4
result[0]: hal.buffer_view
4x1x32000xf32=[[NAN NAN NAN NAN NAN NAN NAN ...
Compile for CPU: D:\dev\projects\iree-build\tools\iree-compile D:\tmp\open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=llvm-cpu -o /tmp/open_llama_3b_v2_f16_llvmcpu_decode_only_17339b.vmfb --iree-hal-executable-debug-level=3
Run on CPU: D:\dev\projects\iree-build\tools\iree-run-module --module=D:\tmp\open_llama_3b_v2_f16_llvmcpu_decode_only_17339b.vmfb --device=local-task --input=4x1xi64 --input=4xi64 --input=4xi64 --input=4x1xi64 --input=1x2662400xf32 --parameters=model=D:\dev\projects\iree-data\huggingface\open_llama_3b_v2_gguf\open-llama-3b-v2-f16.gguf
CPU crashes inside a dispatch (iree_elf_call_i_ppp).
Will trace execution and look at individual dispatches to go deeper.
Still debugging a runtime crash in decode with @rsuderman.
We're suspecting that the in-place scatter operations are writing out of bounds. The exported programs had a sequence of scatters back to back, so Rob has a branch (https://github.com/rsuderman/sharktank/tree/rework_update) that makes the key/value store updates use a single scatter (if I'm understanding correctly). The model fails to compile after those changes.
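For illustration (my own toy example, not Rob's actual change), merging back-to-back in-place updates into one scatter looks roughly like this at the torch level:
import torch

cache_a = torch.zeros(8, 4)
# Before: two in-place scatters back to back (one per update).
cache_a.index_put_((torch.tensor([0]),), torch.ones(1, 4))
cache_a.index_put_((torch.tensor([3]),), torch.full((1, 4), 2.0))

# After: a single scatter with the indices and updates batched together.
cache_b = torch.zeros(8, 4)
idx = torch.tensor([0, 3])
upd = torch.stack([torch.ones(4), torch.full((4,), 2.0)])
cache_b.index_put_((idx,), upd)
assert torch.equal(cache_a, cache_b)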
I have a reduced test case of just a single index_put_ (an in-place operation that lowers to scatter and uses torch.overwrite.tensor.contents) here: https://gist.github.com/ScottTodd/df0d426a351a6737e16f507b187a210b. Looks like an issue in the torch-mlir lowering, since it reproduces with torch-mlir-opt --pass-pipeline="builtin.module(func.func(torch-decompose-complex-ops,convert-torch-to-tmtensor))"
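The shape of that reduced test case as plain torch, roughly (a sketch with assumed shapes; the gist has the actual IR): an index_put_ into a module buffer, which torch.export records as a mutation and torch-mlir models with torch.overwrite.tensor.contents before lowering to scatter.
import torch

class CacheUpdate(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer(
            "cache", torch.zeros(26, 2, 16, 32, 100, dtype=torch.float16))

    def forward(self, idx, update):
        self.cache.index_put_((idx,), update)  # in-place update -> scatter
        return self.cache

ep = torch.export.export(
    CacheUpdate(),
    (torch.tensor([0]), torch.zeros(1, 2, 16, 32, 100, dtype=torch.float16)),
)
print(ep.graph_module.code)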
A different reduced test (IR here, starting from the full llama model) was hitting an assert while compiling: https://gist.github.com/ScottTodd/366fe4b993c3d8e9776c40eddc4a6493
Some debugging around the callstack also pointed at scatter ops:
--- areOpsFusable ---
producer:
%48 = iree_linalg_ext.scatter dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_27, %47 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%expanded_21 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
consumer:
%51 = iree_linalg_ext.scatter {__root_op__ = 17 : i64} dimension_map = [0, 1, 2, 3] unique_indices(false) ins(%expanded_35, %50 : tensor<1x1x1x1x32x100xf16>, tensor<1x4xi32>) outs(%48 : tensor<?x26x2x16x32x100xf16>) {
^bb0(%arg7: f16, %arg8: f16):
iree_linalg_ext.yield %arg7 : f16
} -> tensor<?x26x2x16x32x100xf16>
I'm not sure if that is worth debugging further; it may have been a buggy test case reduction. Going to follow up on the minimal index_put_ compilation error above next.
Filed https://github.com/llvm/torch-mlir/issues/3433 for the index_put_ lowering that fails. Not sure if that's unique to our reduced test cases or if it appears in the full model too. Building out more test coverage and confidence in the operations e2e will help anyway.
> A different reduced test (IR here, starting from the full llama model) was hitting an assert while compiling: https://gist.github.com/ScottTodd/366fe4b993c3d8e9776c40eddc4a6493
This occurs in the full model too. It can be worked around by disabling all dispatch region fusions (adding a return false around here). Should file a reproducer upstream - the compiler must not crash (assert) on valid input. If the input is invalid, then we'd need to update the frontend (torch-mlir / iree-turbine / sharktank).
Compiling with --iree-input-demote-i64-to-i32 works around the runtime crash in the decode() function. We're trying to update the model definition (in the Python source) to use i32 while also digging into why the runtime crashes with i64.
Tried to change dtypes in the model from i64 to i32 (https://github.com/nod-ai/sharktank/compare/main...ScottTodd:llama-i32?expand=1), ran into errors compiling after export like this:
~/iree-build/tools/iree-compile ~/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir -o ~/scratch/open_llama_3b_v2_f16_i32more3_1block_asan.vmfb --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-link-embedded=false --iree-llvmcpu-sanitize=address
/home/scotttodd/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir:11259:11: error: 'arith.cmpi' op requires all operands to have the same type
%43 = torch.aten.index.Tensor %0, %42 : !torch.vtensor<[2048,50],complex<f32>>, !torch.list<optional<vtensor>> -> !torch.vtensor<[4,1,50],complex<f32>>
^
/home/scotttodd/scratch/open_llama_3b_v2_f16_i32more3_1block.mlir:11259:11: note: see current operation: %3741 = "arith.cmpi"(%arg275, %3740) <{predicate = 2 : i64}> : (i32, i64) -> i1
It sounds like https://github.com/iree-org/iree/pull/17696 fixes decode crashes while still using i32 types.
Confirmed that these patches help. All together, I see decode appearing to work (outputs appear sensible and aligned with prefill). Can continue to validate.
Ideas for next steps / follow-up tasks:
Still seeing a crash in decode on Windows with these args:
iree-run-module \
--module=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb \
--function=decode_bs4 \
--device=local-task \
--input=4x1xi64=0 \
--input=4xi64=1 \
--input=4xi64=1 \
--input=4x1xi64=0,1,2,3 \
--input=1x2662400xf16 \
--parameters=model=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.gguf
I'll wrap all my repro steps (documented here: https://github.com/nod-ai/sharktank/pull/69) into a script and run that script across my machines. Hopefully it's just a case of needing the cache (that --input=1x2662400xf16 arg) to be populated.
Goal
Run a llama model from https://github.com/nod-ai/sharktank/blob/main/sharktank/sharktank/models/llama/llama.py through IREE.
Starting with open_llama_3b_v2_f16_gguf since we have that in docs. Could try another model or data type, but we should eventually get all sorts of variants working.
Approach
https://github.com/nod-ai/sharktank/tree/main/sharktank/sharktank/examples has a few files already:
Next steps from there could be:
- compile the exported program with iree-compile and run it using iree-run-module
- an IREE version of paged_llm_v1.py that could either ...
Worklog
Export -> try compile entire program ("prefill" and "decode")
- Exported the model to generate https://sharkpublic.blob.core.windows.net/sharkpublic/scotttodd/issue_reports/open_llama_3b_v2_f16.mlir
- llvm-cpu with default flags: iree-compile open_llama_3b_v2_f16.mlir --iree-hal-target-backends=llvm-cpu -o /tmp/open_llama_3b_v2_f16_cpu.vmfb --iree-hal-executable-debug-level=3 --iree-hal-dump-executable-files-to=/tmp/open_llama_3b_v2_f16_cpu. That got stuck compiling after LLVMCPUVectorTransferLowering: https://github.com/iree-org/iree/issues/17244#issuecomment-2099201467
- vulkan-spirv with default flags: iree-compile open_llama_3b_v2_f16.mlir --iree-hal-target-backends=vulkan-spirv -o /tmp/open_llama_3b_v2_f16_vulkan.vmfb. That hit two different spirv codegen issues: https://github.com/iree-org/iree/issues/17304 (failed to legalize operation 'arith.extui' with i1 -> i64 on CPU, 'spirv.IAdd' with i1 on Vulkan)
Next: continue triaging compilation errors for prefill.
Export and run just "decode"
- Compiled with: iree-compile open_llama_3b_v2_f16_decode_only.mlir --iree-hal-target-backends=vulkan-spirv --iree-vulkan-target-triple=turing-unknown-unknown -o /tmp/open_llama_3b_v2_f16_vulkan_decode_only.vmfb --iree-hal-executable-debug-level=3
- For iree-run-module I need the inputs and a parameter file:
  - --input=4xi64 --input=4xi64 --input=4xi64 --input=4xi64 --input=1x2662400xf32 (need to verify)
  - huggingface-cli download --local-dir /tmp/open_llama_3b_v2_gguf SlyEcho/open_llama_3b_v2_gguf (that folder then contains /tmp/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf)
- iree-run-module --module=/tmp/open_llama_3b_v2_f16_vulkan_decode_only.vmfb --device=vulkan --input=4xi64 --input=4xi64 --input=4xi64 --input=4xi64 --input=1x2662400xf32 --parameters=model=/tmp/open_llama_3b_v2_gguf/open-llama-3b-v2-f16.gguf produces this error: iree\runtime\src\iree\io\formats\gguf\gguf_parser.c:678: UNIMPLEMENTED; GGUF format version 2 is unsupported; expected version 3
Next: try upgrading GGUF version 2 to 3? Load from safetensors? Convert to IRPA?