mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

one_file_ref.py attention has an O(seqlen^2) matrix multiplication when prefilling #38

Closed Aniruddha-Deb closed 11 months ago

Aniruddha-Deb commented 11 months ago

Lines 129-143 in one_file_ref.py multiply the full query and key matrices with each other when prefilling the key-value cache. The sliding-window mask is applied only after this multiplication:


        if positions.shape[0] > 1:
            # prefill
            key, value = repeat_kv(xk, xv, self.repeats)
        else:
            cur_pos = positions[-1].item() + 1
            key, value = repeat_kv(self.cache_k[:bsz, :cur_pos, ...], self.cache_v[:bsz, :cur_pos, ...], self.repeats)

        query = xq.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        # scores : [bsz, n_heads, seqlen | 1, seqlen]
        scores = torch.matmul(query, key.transpose(2, 3)) * self.scale
        # this operation is O(seqlen^2), not O(seqlen * sliding_window)

        if mask is not None:
            scores += mask[None, None, ...]

This seems inefficient for prompt sizes greater than the sliding-window length, and could be improved by using the attention implementation in mistral/model.py directly (which uses xformers' memory_efficient_attention).
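
For illustration only, here is a minimal pure-PyTorch sketch of a windowed prefill. It is not the repository's implementation; the `sliding_window_prefill` helper and the [bsz, seqlen, n_heads, head_dim] input shapes are assumptions. It processes the prompt in chunks of the window size, so each chunk's score matrix covers at most sliding_window x (2 * sliding_window - 1) positions:

    import torch

    def sliding_window_prefill(xq, xk, xv, sliding_window: int, scale: float):
        # Hypothetical helper, not part of the repo: chunked causal
        # sliding-window attention for prefill.
        # xq, xk, xv: [bsz, seqlen, n_heads, head_dim]; repeat_kv is assumed
        # to have been applied to xk/xv already.
        bsz, seqlen, n_heads, head_dim = xq.shape
        out = torch.empty_like(xq)
        W = sliding_window

        for start in range(0, seqlen, W):
            end = min(start + W, seqlen)
            # earliest key any query in this chunk can attend to
            kv_start = max(0, start - W + 1)

            # [bsz, n_heads, chunk, head_dim] / [bsz, n_heads, kv_len, head_dim]
            q = xq[:, start:end].transpose(1, 2)
            k = xk[:, kv_start:end].transpose(1, 2)
            v = xv[:, kv_start:end].transpose(1, 2)

            # score matrix is at most W x (2W - 1), never seqlen x seqlen
            scores = torch.matmul(q, k.transpose(2, 3)) * scale

            # causal sliding-window mask: query i sees keys j with i - W < j <= i
            q_pos = torch.arange(start, end, device=xq.device)[:, None]
            k_pos = torch.arange(kv_start, end, device=xq.device)[None, :]
            allowed = (k_pos <= q_pos) & (k_pos > q_pos - W)
            scores = scores.masked_fill(~allowed, float("-inf"))

            attn = torch.softmax(scores, dim=-1)
            out[:, start:end] = torch.matmul(attn, v).transpose(1, 2)

        return out

With this chunking, total prefill work grows linearly with prompt length instead of quadratically; the suggestion above of reusing mistral/model.py's xformers-based attention would get the same asymptotics with an optimized kernel.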

timlacroix commented 11 months ago

Yeah, the idea for this file was to be as simple as possible, so people can understand what computation is being done if they want to re-implement it in any framework they want.

This is a bit easier to understand and parse than the CUDA kernels in xformers or FlashAttention ;)