Closed · huijjj closed 3 weeks ago
@huijjj @kzawora-intel I think this issue can be closed as it is implemented in https://github.com/HabanaAI/vllm-fork/pull/162, can you confirm?
@michalkuligowski Yes, I have included the change in #162, so you can close this issue. However, I am still not sure whether it is okay to make this change, as I do not know the purpose of the original code yet. So please take a look and leave me comments there if needed. Thank you.
Proposal to improve performance
https://github.com/HabanaAI/vllm-fork/blob/37ca17f0097dae0a03fee6936062871ec49e2351/vllm/hpu/rotary_embed.py#L80-L115
The current `HpuRotaryEmbedding` forward implementation slices the cached buffer (`[:seq_len]`) before performing the `index_select` (i.e. `[position]`) with the positions during the prefill phase. `seq_len` is simply obtained by examining the dimensions of the inputs: since prefills currently have no context (pre-computed blocks), it is safe to assume that the sequence length is identical to one of the input dimensions. So the slicing can be done without any data dependencies, using only tensor metadata, which in this case is the tensor shape. However, once prefills can have context (such as with prefix caching or chunked prefill), this is no longer the case. Obtaining `seq_len` then requires a data-dependent flow, as the input dimensions are no longer necessarily identical to `seq_len`; it must be fetched from the maximum value of `positions`, which requires instructions like `torch.max(positions)`. And what's worse, as the slice index becomes dynamic, the slicing has to be recompiled every single time, which slows down inference a lot. So I'd like to remove the slicing, since I have no idea why it is needed. If there is a hidden intention behind it, please let me know so that I can come up with a better idea.
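A minimal sketch of the difference, using a hypothetical `cos_cached` buffer; the names and shapes here are illustrative, not the actual code from `rotary_embed.py`:

```python
import torch

# Hypothetical cache, built once at init up to max_position_embeddings.
max_position_embeddings = 4096
head_dim = 128
cos_cached = torch.randn(max_position_embeddings, head_dim)


def lookup_plain_prefill(positions: torch.Tensor, seq_len: int) -> torch.Tensor:
    # Plain prefill: seq_len comes straight from the input's shape (tensor
    # metadata), so the slice bound is a static Python int and the compiled
    # graph can be reused across inputs of the same shape.
    return cos_cached[:seq_len][positions]


def lookup_prefill_with_context(positions: torch.Tensor) -> torch.Tensor:
    # Prefix caching / chunked prefill: the input shape no longer tells us
    # the sequence length, so the bound has to be read from the data. A
    # data-dependent slice bound forces a recompilation whenever it changes.
    seq_len = int(torch.max(positions)) + 1
    return cos_cached[:seq_len][positions]
```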
Also, I'd like to remove the `if` branch checking whether a given position is already cached or not. vLLM caches the cos and sin values with respect to the model configuration in the initialization step, using the value specified in `max_position_embeddings`. So the cache already holds its maximum size, and we do not need to check whether a given position is cached.

Summarizing all the changes I mentioned above, the simplified code is sketched below.
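A minimal sketch of what the forward could look like with both changes applied, assuming a standard LLaMA-style cache of shape `[max_position_embeddings, head_dim]`; the helper names, shapes, and broadcasting details are assumptions, not the exact code from `rotary_embed.py` or the PR:

```python
import torch


def build_cos_sin_cache(head_dim: int, max_position_embeddings: int,
                        base: float = 10000.0):
    # Built once at init, so every valid position is already covered and the
    # "is this position cached yet?" check is unnecessary.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_position_embeddings).float()
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)            # [max_pos, head_dim]
    return emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotary helper: swap and negate the two halves of the last dim.
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rotary_pos_emb(positions: torch.Tensor,
                         query: torch.Tensor,
                         key: torch.Tensor,
                         cos_cached: torch.Tensor,
                         sin_cached: torch.Tensor):
    # A plain gather with `positions` is always in range, so there is no
    # [:seq_len] slice, no cache-size branch, and nothing data-dependent
    # that would trigger a recompilation.
    cos = cos_cached[positions].unsqueeze(-2).to(query.dtype)  # broadcast over heads
    sin = sin_cached[positions].unsqueeze(-2).to(query.dtype)
    query = query * cos + rotate_half(query) * sin
    key = key * cos + rotate_half(key) * sin
    return query, key
```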
If it is okay with you, I'll add this to my PR about enabling prefix caching.