Open renxida opened 3 hours ago
What I see
Usually the kvcache arg looks like
%arg4: !torch.tensor<[?,2662400],f16>
and is the last arg in decode_bsX and prefill_bsX
But when I export ONLY bs=1, I see 50+ arguments, most of which look like:
%arg3: !torch.tensor<[1,2048,32,100],f16>, %arg4: !torch.tensor<[1,2048,32,100],f16>, %arg5: !torch.tensor<[1,2048,32,100],f16>, %arg6: !torch.tensor<[1,2048,32,100],f16>,
Interestingly, --bs=1,4 produces normal-looking kvcache args.
What I hoped to see
Instead of many small arrays, the kvcache should be accepted as one big array, as expected by e.g. shortfin. Having many small arrays results in an error.
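As a sanity check on the two shapes quoted in this issue, the arithmetic below suggests they could describe the same cache laid out in two different ways. The layer count (26) and the block sequence stride (16) are assumptions that do not appear anywhere in this issue; the head count (32) and head dimension (100) are read off the per-layer shape [1,2048,32,100].

```python
# Assumed values (NOT stated in the issue): 26 transformer layers and a
# block sequence stride of 16 tokens per cache page.
ASSUMED_LAYERS = 26
ASSUMED_BLOCK_SEQ_STRIDE = 16

# Read off the per-layer shape [1, 2048, 32, 100] seen in the bs=1 export.
HEADS = 32
HEAD_DIM = 100
KV_PARTITIONS = 2  # one K tensor and one V tensor per layer


def paged_page_size(layers, stride, heads, head_dim):
    """Elements in one flattened page holding K and V for every layer."""
    return layers * KV_PARTITIONS * stride * heads * head_dim


def direct_tensor_count(layers):
    """Number of separate cache tensors if K and V are passed per layer."""
    return layers * KV_PARTITIONS


page_size = paged_page_size(
    ASSUMED_LAYERS, ASSUMED_BLOCK_SEQ_STRIDE, HEADS, HEAD_DIM
)
tensor_count = direct_tensor_count(ASSUMED_LAYERS)

print(page_size)     # 2662400, matching the [?,2662400] argument
print(tensor_count)  # 52, consistent with "50+ arguments"
```

Under these assumptions, the single [?,2662400] argument and the 50+ [1,2048,32,100] arguments would hold the same kind of data, one as a flattened page table and the other as one tensor per layer per K/V. This is only a consistency check on the numbers, not a diagnosis of the export script.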
Watch me replicate it
https://asciinema.org/a/UapPmn4h1wwuTxM6cptAYyXcr
Context
The cache state params are created here for export_paged_llm_v1.py:
https://github.com/nod-ai/SHARK-Platform/blob/4a49a847c873f4940bb0786283c3ddf25c4c6da3/sharktank/sharktank/examples/export_paged_llm_v1.py#L120-L133
The cache state is passed to export_program here: https://github.com/nod-ai/SHARK-Platform/blob/4a49a847c873f4940bb0786283c3ddf25c4c6da3/sharktank/sharktank/examples/export_paged_llm_v1.py#L144-L148
How to replicate
Run this script: https://gist.github.com/renxida/749d6a26cf94d05ba9db6342064aaa3f
You should see the same thing as: https://asciinema.org/a/UapPmn4h1wwuTxM6cptAYyXcr
or, if you prefer pure text: https://gist.github.com/renxida/ca262a5110701bbc024359026426952a
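To check an export without reading the function signature by eye, a small script can parse the argument list. This is a sketch only: the helper names are hypothetical, the regex is written against the two argument strings quoted in this issue, and it assumes the signature text has already been pulled out of the exported MLIR.

```python
import re

# Matches arguments such as: %arg4: !torch.tensor<[?,2662400],f16>
ARG_PATTERN = re.compile(
    r"%arg(\d+):\s*!torch\.tensor<\[([^\]]*)\],\s*(\w+)>"
)


def parse_tensor_args(signature_text):
    """Return (arg_index, dims, dtype) for each tensor argument found."""
    parsed = []
    for match in ARG_PATTERN.finditer(signature_text):
        index = int(match.group(1))
        dims = [d.strip() for d in match.group(2).split(",")]
        parsed.append((index, dims, match.group(3)))
    return parsed


def looks_like_single_page_table(signature_text):
    """Heuristic: the last tensor arg is rank 2 with a dynamic first dim."""
    args = parse_tensor_args(signature_text)
    if not args:
        return False
    _, dims, _ = args[-1]
    return len(dims) == 2 and dims[0] == "?"


# The two argument styles quoted in this issue.
normal = "%arg4: !torch.tensor<[?,2662400],f16>"
split = (
    "%arg3: !torch.tensor<[1,2048,32,100],f16>, "
    "%arg4: !torch.tensor<[1,2048,32,100],f16>"
)

print(looks_like_single_page_table(normal))  # True
print(looks_like_single_page_table(split))   # False
```

The check is deliberately loose: it only looks at the last tensor argument, since the issue states that the kvcache is normally the last arg in decode_bsX and prefill_bsX.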
I found this while debugging the kvcache NaN issue in shortfin LLM serving. I would really love to have bs=1 working so I can isolate the issue.
cc @aviator19941 and @archana-ramalingam, because I'm hoping how this script works is still fresh in your minds.