renxida opened this issue 2 weeks ago
Currently taking a look at how the kvcache size is calculated in shortfin.
Looks like the expected 1048576 is wrong; it should be:

32 (transformer layer count) × 2 (kv) × 32 (attention heads) × 16 (tokens per block) × 128 (head dimension) = 4194304
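For concreteness, the same product as a quick check (variable names here are descriptive, not shortfin's actual identifiers):

```python
# Per-page kv cache element count implied by the numbers above.
# Names are illustrative, not shortfin's actual variables.
transformer_block_count = 32   # transformer layers
cache_partition_count = 2      # k and v
attn_head_count = 32           # attention heads
block_seq_stride = 16          # tokens per block
attn_head_dim = 128            # head dimension

page_elems = (transformer_block_count * cache_partition_count
              * attn_head_count * block_seq_stride * attn_head_dim)
print(page_elems)  # 4194304
```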
Taking a look at the MLIR next. Also, @stbaione got a different error on GPU, so I think I'm running into an IREE CPU issue or possibly a sharktank export issue.
```mlir
func.func @prefill_bs1(%arg0: !torch.vtensor<[1,?],si64> {iree.abi.affinity = #hal.device.promise<@__device_0>}, %arg1: !torch.vtensor<[1],si64> {iree.abi.affinity = #hal.device.promise<@__device_0>}, %arg2: !torch.vtensor<[1,?],si64> {iree.abi.affinity = #hal.device.promise<@__device_0>}, %arg3: !torch.tensor<[?,1048576],f16>) -> !torch.vtensor<[1,?,128256],f16> attributes {torch.assume_strict_symbolic_shapes} {
```
The wrong cache shape appears in both prefill_bs1 and prefill_bs4.
Confirmed that cache dim 1 is 1048576 in export_paged_llm
Found that the kvcache is created with:
```
Creating paged kv cache with:
transformer_block_count: 32
attn_head_count: 8
attn_head_dim: 128
cache_partition_count: 2
block_seq_stride: 16
dtype: torch.float16
device: None
shard_count: 1
```
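Plugging these printed values into the same product reproduces the "expected" size from the error (a quick illustrative check, not shortfin code):

```python
# 32 layers * 2 (k/v) * 8 heads * 16 tokens/block * 128 head dim
page_elems = 32 * 2 * 8 * 16 * 128
print(page_elems)  # 1048576, the size the error message called "expected"
```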
So it's probably because the number of attention heads got divided by 4 somewhere.
OK, I think I figured this one out. There's something going on with grouped-query attention: the kvcache is shared in groups of 4 attention heads, so the cache is effectively supposed to have 4x fewer attention heads.
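A minimal sketch of that head grouping (illustrative names, not sharktank's actual API):

```python
# Grouped-query attention: query heads share k/v heads in groups.
attention_head_count = 32      # query heads
attention_head_count_kv = 8    # k/v heads actually stored in the cache
group_size = attention_head_count // attention_head_count_kv
assert group_size == 4         # each k/v head serves 4 query heads

# Mapping from a query head index to the k/v head it reads:
def kv_head_for(query_head: int) -> int:
    return query_head // group_size

print([kv_head_for(q) for q in range(attention_head_count)])
# [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7]
```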
LlamaHParams has attention_head_count_kv=8, attention_head_count=32:

```
LlamaHParams(model_arch='llama', context_length=131072, embedding_length=4096, block_count=32, feed_forward_length=14336, attention_head_count=32, attn_head_dim=128, attention_layer_norm_rms_epsilon=9.999999747378752e-06, attention_head_count_kv=8, rope_dimension_count=128, rope_freq_base=500000.0, expert_count=0, expert_used_count=0)
```
And export_paged_llm_v1.py exports a config.json with attention_head_count=32, but shortfin actually expects the attention_head_count in the exported json to refer to attention_head_count_kv.
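A hypothetical reconstruction of the disagreement (field names assumed from the printout above; the computation is a sketch, not shortfin's actual code):

```python
# Hypothetical reconstruction, not shortfin's actual implementation.
exported_config = {
    "transformer_block_count": 32,
    "attention_head_count": 32,  # export_paged_llm_v1.py writes the query head count
    "attn_head_dim": 128,
    "block_seq_stride": 16,
}

def page_elems(kv_head_count: int) -> int:
    # layers * 2 (k/v) * kv heads * tokens per block * head dim
    return (exported_config["transformer_block_count"] * 2 * kv_head_count
            * exported_config["block_seq_stride"]
            * exported_config["attn_head_dim"])

# shortfin treats the json's attention_head_count as the kv head count:
print(page_elems(exported_config["attention_head_count"]))  # 4194304: page size shortfin derives
# ...but the compiled prefill functions were built with the true kv head count:
print(page_elems(8))  # 1048576: matches the ?x1048576 cache in the mlir
```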
Superseded by #405
Error:
It seems to expect the kvcache page to be 4x smaller than it actually is:
Repro steps:
Download weights:
Download tokenizer & compile:
https://gist.github.com/renxida/9d974a909f401f451b58e9dc2052af16
(this also prints the commands needed to serve and run the client)