HabanaAI / vllm-fork

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Performance]: context aware HpuRotaryEmbedding implementation #166

Closed huijjj closed 3 weeks ago

huijjj commented 3 months ago

Proposal to improve performance

https://github.com/HabanaAI/vllm-fork/blob/37ca17f0097dae0a03fee6936062871ec49e2351/vllm/hpu/rotary_embed.py#L80-L115

The current HpuRotaryEmbedding forward implementation slices the cached buffer ([:seq_len]) and then performs an index_select ([positions]) on it during the prefill phase. Today seq_len is obtained simply by examining the dimensions of the inputs: prefills currently have no context (pre-computed blocks), so it is safe to assume the sequence length equals one of the input dimensions. The slicing therefore needs no data dependency; it can be derived from tensor metadata alone (here, the tensor shape). However, once prefills can have context (e.g. with prefix caching or chunked prefill), this is no longer the case. The input dimensions are no longer necessarily equal to seq_len, so obtaining it requires a data-dependent flow such as torch.max(positions). Worse, because the slice bound becomes dynamic, the slicing has to be recompiled every time it changes, which slows down inference considerably. I would therefore like to remove the slicing, since I see no reason it is needed. If there is a hidden intention behind it, please let me know so that I can come up with a better approach.
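
To make the recompilation concern concrete, here is a minimal sketch (not code from the repository; the cache size, shapes, and function names are illustrative assumptions) contrasting the current slice-then-index pattern with direct indexing into the full cache:

    import torch

    # Illustrative stand-in for the precomputed cache (max_position_embeddings x rotary_dim).
    cos_cached = torch.randn(4096, 128)

    def slice_then_index(positions: torch.Tensor, seq_len: int) -> torch.Tensor:
        # Current pattern: the slice bound comes from an input dimension, so it is
        # plain tensor metadata as long as prefills have no context.
        return cos_cached[:seq_len][positions]

    def slice_then_index_with_context(positions: torch.Tensor) -> torch.Tensor:
        # With prefix caching / chunked prefill the bound must be derived from the
        # data itself, making the slice shape dynamic and forcing a recompilation
        # whenever it changes.
        seq_len = int(positions.max().item()) + 1
        return cos_cached[:seq_len][positions]

    def index_only(positions: torch.Tensor) -> torch.Tensor:
        # Proposed pattern: index the full cache directly; no slice, no dynamic bound.
        return cos_cached[positions]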

I would also like to remove the if branch that checks whether a given position is already cached. vLLM precomputes the cos and sin caches at initialization from the model configuration, using the value specified in max_position_embeddings, so the cache already covers its maximum size and the check is never needed at forward time.
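
For reference, the cache is typically built once at construction time along the lines of the sketch below (the class name, helper layout, and exact buffer setup follow the common HuggingFace-style rotary embedding and are assumptions, not quotes from rotary_embed.py), so every position up to max_position_embeddings is already covered when forward runs:

    import torch

    class RotaryCacheSketch(torch.nn.Module):
        # Sketch only: mirrors the usual rotary-embedding cache setup.
        def __init__(self, head_size: int, max_position_embeddings: int,
                     base: float = 10000.0):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
            t = torch.arange(max_position_embeddings, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            # Precomputed once up to max_position_embeddings, so forward never needs
            # to check whether a given position is cached.
            self.register_buffer("cos_cached", emb.cos(), persistent=False)
            self.register_buffer("sin_cached", emb.sin(), persistent=False)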

Summarizing the changes above, the resulting code is:

    def forward(self, positions: torch.Tensor, query: torch.Tensor,
                key: torch.Tensor):
        if FusedRoPE is None:
            return self.fallback_impl(positions, query, key)
        if query.dim() == 2:
            query = query.unsqueeze(0)
        if key.dim() == 2:
            key = key.unsqueeze(0)
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)

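        # Unpack the flattened hidden dimension into (num_heads, head_size).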
        query = query.reshape(
            (query.shape[0], query.shape[1], query.shape[2] // self.head_size,
             self.head_size))
        key = key.reshape((key.shape[0], key.shape[1],
                           key.shape[2] // self.head_size, self.head_size))

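        # Look up cos/sin directly from the full precomputed cache (no [:seq_len] slice).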
        cos = self.cos_cached[positions].unsqueeze(2).to(dtype=query.dtype)
        sin = self.sin_cached[positions].unsqueeze(2).to(dtype=query.dtype)
        query = FusedRoPE.apply(query, cos, sin, 0)
        key = FusedRoPE.apply(key, cos, sin, 0)
        return query.reshape(
            (query.shape[0], query.shape[1],
             query.shape[2] * query.shape[3])), key.reshape(
                 (key.shape[0], key.shape[1], key.shape[2] * key.shape[3]))

If it is okay with you, I'll add this to my PR about enabling prefix caching.

michalkuligowski commented 2 months ago

@huijjj @kzawora-intel I think this issue can be closed, as it is implemented in https://github.com/HabanaAI/vllm-fork/pull/162. Can you confirm?

huijjj commented 2 months ago

@michalkuligowski Yes, I have included the change in #162, so you can close this issue. However, I am still not sure whether this change is safe, as I do not yet know the purpose of the original code. Please take a look and leave comments there if needed. Thank you.