nod-ai / SHARK-Platform

SHARK Inference Modeling and Serving
Apache License 2.0

llama3 8b f16 kvcache dimension mismatch on shortfin #401

Open · renxida opened this issue 2 weeks ago

renxida commented 2 weeks ago

Error:

It seems to be expecting the kvcache page to be 4x smaller than it actually is:

```
INVALID_ARGUMENT; tensor shape dimension 1 mismatch; expected 1048576 but have 4194304; expected shape `256x1048576`, actual shape `256x4194304`; while invoking native function hal.buffer_view.assert; while calling import;
```

Full log:

```
[2024-10-31 12:13:38.339] [info] [service.py:390] INVOKE ProgramFunction(prefill_bs1$async: 0rrrrrr_r): 0: [1, 16] 1: [1] 2: [1, 1] 3: [256, 4194304]
[2024-10-31 12:13:38.339] [error] [service.py:415] Fatal error in prefetch invocation
Traceback (most recent call last):
  File "/home/xidaren2/SHARK-Platform/shortfin/python/shortfin_apps/llm/components/service.py", line 396, in run
    (logits,) = await fn(*args, fiber=self.fiber)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: ValueError: shortfin_iree-src/runtime/src/iree/modules/hal/utils/buffer_diagnostics.c:221: INVALID_ARGUMENT; tensor shape dimension 1 mismatch; expected 1048576 but have 4194304; expected shape `256x1048576`, actual shape `256x4194304`; while invoking native function hal.buffer_view.assert; while calling import;
[ 0] bytecode module.prefill_bs1$async:3530 /home/xidaren2/xshortfin/llama3_8b/export/model.mlir:293:294
/home/xidaren2/SHARK-Platform/shortfin/src/shortfin/support/iree_helpers.h:315: UNKNOWN; Unhandled exception:
Traceback (most recent call last):
  File "/home/xidaren2/SHARK-Platform/shortfin/python/_shortfin/asyncio_bridge.py", line 79, in _sf_maybe_run
  File "/home/xidaren2/miniforge3/envs/env1/lib/python3.12/asyncio/events.py", line 103, in _run
  File "/home/xidaren2/SHARK-Platform/shortfin/python/shortfin_apps/llm/components/generate.py", line 78, in run
RuntimeError: Async exception on ): argmax(): incompatible function arguments. The following argument types are supported:
    1. argmax(input: _shortfin_default.lib.array.device_array, axis: int = -1, out: _shortfin_default.lib.array.device_array | None = None, *, keepdims: bool = False, device_visible: bool = False) -> _shortfin_default.lib.array.device_array
Invoked with types: NoneType
```
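The two shapes in the log already pin down the discrepancy. A quick sanity check (plain Python, just arithmetic on the two values above, nothing from shortfin itself):

```
# Dimension 1 of the KV cache page buffer, taken from the error message:
expected_dim1 = 1048576  # expected shape `256x1048576`
actual_dim1 = 4194304    # actual shape `256x4194304`

# The actual buffer is exactly 4x larger than what the compiled program expects.
assert actual_dim1 % expected_dim1 == 0
print(actual_dim1 // expected_dim1)  # 4
```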

Repro steps:

Download weights:

```
az storage blob download \
  --container-name halo-models \
  --account-name sharkblobs \
  --name llm-dev/llama3_8b/8b_f16.irpa \
  --file model.irpa \
  --account-key=[redacted]

az storage blob download \
  --container-name halo-models \
  --account-name sharkblobs \
  --name llm-dev/llama3_8b/llama8b_f16.gguf \
  --file model.gguf \
  --account-key=[redacted]
```

Download tokenizer & compile:

https://gist.github.com/renxida/9d974a909f401f451b58e9dc2052af16

(this also prints the commands needed to serve and run the client)

renxida commented 2 weeks ago

currently taking a look at how the kvcache size is calculated in shortfin

renxida commented 2 weeks ago

Looks like the expected value 1048576 is wrong, and it should instead be:

32 (transformer layer count) × 2 (k and v) × 32 (attention heads) × 16 (tokens per block) × 128 (head dimension) = 4194304
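For reference, a minimal sketch of that computation (plain Python; the variable names just label the factors above and are not the actual sharktank/shortfin variables):

```
# Flattened size of one KV cache page, built from the factors listed above.
transformer_block_count = 32  # transformer layers
cache_partition_count = 2     # k and v
attn_head_count = 32          # attention heads
block_seq_stride = 16         # tokens per block
attn_head_dim = 128           # head dimension

page_size = (
    transformer_block_count
    * cache_partition_count
    * attn_head_count
    * block_seq_stride
    * attn_head_dim
)
print(page_size)             # 4194304 -- what shortfin allocates
print(page_size // 1048576)  # 4 -- the mismatch factor from the error
```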

renxida commented 2 weeks ago

Taking a look at the MLIR next. Also, @stbaione got a different error on GPU, so I think I'm running into an IREE CPU issue or possibly a sharktank export issue.

renxida commented 2 weeks ago

```
func.func @prefill_bs1(%arg0: !torch.vtensor<[1,?],si64> {iree.abi.affinity = #hal.device.promise<@device_0>}, %arg1: !torch.vtensor<[1],si64> {iree.abi.affinity = #hal.device.promise<@__device_0>}, %arg2: !torch.vtensor<[1,?],si64> {iree.abi.affinity = #hal.device.promise<@device_0>}, %arg3: !torch.tensor<[?,1048576],f16>) -> !torch.vtensor<[1,?,128256],f16> attributes {torch.assume_strict_symbolic_shapes} {
```

The wrong cache shape appears in both the _bs1 and _bs4 functions.

renxida commented 2 weeks ago

model.mlir: https://gist.github.com/renxida/69edfbbac1f24257bead6adb43ca0947

renxida commented 2 weeks ago

Confirmed that cache dim 1 is 1048576 in export_paged_llm

Found that the kvcache is created with:

```
Creating paged kv cache with:
transformer_block_count: 32
attn_head_count: 8
attn_head_dim: 128
cache_partition_count: 2
block_seq_stride: 16
dtype: torch.float16
device: None
shard_count: 1
```

So it's probably because the number of attention heads got divided by 4 somewhere.
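Running the same product with attn_head_count=8, as in the cache-creation log above, does land on the exported shape (plain Python again, not the export code itself):

```
# Same page-size product, but with the 8 attention heads used at export time:
page_size_export = 32 * 2 * 8 * 16 * 128
print(page_size_export)              # 1048576 -- matches dim 1 in the exported MLIR

# shortfin's figure is exactly 4x larger, i.e. as if it had used 32 heads:
print(4194304 // page_size_export)   # 4
```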

renxida commented 2 weeks ago

OK, I think I figured this one out. This is grouped-query attention: the KV cache is shared across groups of 4 attention heads, so the KV cache effectively has 4x fewer heads than the query side.

LlamaHParams has attention_head_count_kv=8 and attention_head_count=32:

```
LlamaHParams(model_arch='llama', context_length=131072, embedding_length=4096, block_count=32, feed_forward_length=14336, attention_head_count=32, attn_head_dim=128, attention_layer_norm_rms_epsilon=9.999999747378752e-06, attention_head_count_kv=8, rope_dimension_count=128, rope_freq_base=500000.0, expert_count=0, expert_used_count=0)
```

export_paged_llm_v1.py exports a config.json with attention_head_count=32, but shortfin actually expects the attention_head_count field in the exported JSON to refer to attention_head_count_kv.
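To make the mismatch concrete, a hypothetical sketch (the keys mirror the LlamaHParams fields above; this is not shortfin's actual config parsing, and the real config.json field names may differ):

```
# Hypothetical config values, mirroring the LlamaHParams dump above.
exported_config = {
    "attention_head_count": 32,     # query heads -- what export_paged_llm_v1.py writes out
    "attention_head_count_kv": 8,   # kv heads under grouped-query attention
    "transformer_block_count": 32,
    "cache_partition_count": 2,
    "block_seq_stride": 16,
    "attn_head_dim": 128,
}

def page_size(head_count: int) -> int:
    # Flattened size of one paged KV cache block for a given head count.
    return (
        exported_config["transformer_block_count"]
        * exported_config["cache_partition_count"]
        * head_count
        * exported_config["block_seq_stride"]
        * exported_config["attn_head_dim"]
    )

print(page_size(exported_config["attention_head_count"]))     # 4194304 -- shortfin's (wrong) size
print(page_size(exported_config["attention_head_count_kv"]))  # 1048576 -- the shape baked into the MLIR
```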

renxida commented 2 weeks ago

Superseded by #405.