nod-ai / SHARK-Platform

SHARK Inference Modeling and Serving
Apache License 2.0

paged_llm with `--bs=1` exports func.func with too many kvcache arguments #312

Open renxida opened 3 hours ago

renxida commented 3 hours ago

What I see

Usually the kvcache arg looks like

 %arg4: !torch.tensor<[?,2662400],f16>

and is the last arg in decode_bsX and prefill_bsX

But when I export ONLY bs=1, I see 50+ arguments, most of which look like:

 %arg3: !torch.tensor<[1,2048,32,100],f16>, %arg4: !torch.tensor<[1,2048,32,100],f16>, %arg5: !torch.tensor<[1,2048,32,100],f16>, %arg6: !torch.tensor<[1,2048,32,100],f16>, 

Interestingly, `--bs=1,4` produces normal-looking kvcache args.
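For reference, the two shapes above are consistent with a single paged cache buffer vs. a per-layer (direct) cache. Here is a minimal sketch of the shape arithmetic; the specific factorization (26 transformer layers, 32 heads, head_dim 100, block_seq_stride 16) is inferred from the numbers in this report, not confirmed from the model config:

```python
# Hypothetical shape arithmetic, NOT taken from the repo. One factorization
# consistent with the arg shapes reported above.

def paged_cache_shape(layers, heads, head_dim, block_seq_stride, num_pages):
    # Paged cache: ONE flat buffer of shape [num_pages, page_stride], where
    # a page packs K and V for every layer:
    # page_stride = layers * 2 (K,V) * heads * block_seq_stride * head_dim
    page_stride = layers * 2 * heads * block_seq_stride * head_dim
    return (num_pages, page_stride)

def direct_cache_shapes(layers, heads, head_dim, bs, seq_len):
    # Direct cache: a separate K and a separate V tensor per layer, i.e.
    # 2 * layers arguments, each of shape [bs, seq_len, heads, head_dim].
    return [(bs, seq_len, heads, head_dim)] * (2 * layers)

print(paged_cache_shape(26, 32, 100, 16, 1))           # (1, 2662400)
print(len(direct_cache_shapes(26, 32, 100, 1, 2048)))  # 52
```

Under this (assumed) factorization the paged layout yields the single `[?, 2662400]` argument, while the direct layout yields 52 separate `[1,2048,32,100]` arguments, matching the 50+ args observed with `--bs=1`.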

What I hoped to see

Instead of many small arrays, the kvcache should be accepted as one big array, as consumers such as shortfin expect. Passing many small arrays results in an error downstream.

Watch me replicate it

https://asciinema.org/a/UapPmn4h1wwuTxM6cptAYyXcr

Context

The cache state params are created here for export_paged_llm_v1.py

https://github.com/nod-ai/SHARK-Platform/blob/4a49a847c873f4940bb0786283c3ddf25c4c6da3/sharktank/sharktank/examples/export_paged_llm_v1.py#L120-L133

The cache state is passed to export_program here: https://github.com/nod-ai/SHARK-Platform/blob/4a49a847c873f4940bb0786283c3ddf25c4c6da3/sharktank/sharktank/examples/export_paged_llm_v1.py#L144-L148
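To make the flow concrete, here is a hypothetical sketch of the export pattern (NOT the actual export_paged_llm_v1.py code): for each requested batch size, a cache state is allocated and prefill/decode entry points are registered that take that cache state as an argument.

```python
# Hypothetical sketch, not the real exporter. All names here are invented
# for illustration.
def build_entry_points(batch_sizes, allocate_cache_state):
    entry_points = {}
    for bs in batch_sizes:
        # For a paged cache this should be a SINGLE flat tensor; the
        # behavior in this report suggests a per-layer list of tensors is
        # produced instead when only bs=1 is requested.
        cache_state = allocate_cache_state(bs)
        entry_points[f"prefill_bs{bs}"] = cache_state
        entry_points[f"decode_bs{bs}"] = cache_state
    return entry_points
```

If `allocate_cache_state` returns a list of `2 * layers` tensors instead of one flat tensor, each list element would become its own `func.func` argument in the exported module, which would match the 50+ arguments observed above.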

How to replicate

Run this script: https://gist.github.com/renxida/749d6a26cf94d05ba9db6342064aaa3f

You should see the same thing as:

https://asciinema.org/a/UapPmn4h1wwuTxM6cptAYyXcr

or if you prefer pure text https://gist.github.com/renxida/ca262a5110701bbc024359026426952a

I found this while debugging the kvcache NaN issue in shortfin LLM serving. I would really love to have bs=1 working so I can isolate the issue.

renxida commented 3 hours ago

Tagging @aviator19941 and @archana-ramalingam because I'm hoping how this script works is still fresh in your minds.