mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

How is The 131K Attention Span Achieved? #71

Open ThePerfectComputer opened 9 months ago

ThePerfectComputer commented 9 months ago

The Mistral 7B paper claims a theoretical attention span of about 131K tokens for Mistral 7B, obtained by propagating information up through the layers with sliding window attention (SWA). I'm trying to figure out how this is achieved in practice.

The trick seems to be on line 128, with the branch `if positions.shape[0] > 1:`, which would typically be taken when the model is first called with the prompt. From my understanding, taking this branch computes k/v values for all provided tokens, which could then propagate information from an initial prompt of up to 131K tokens through the model layers (due to the staggered nature of the sliding window). The branch would not be taken again afterwards, because additional tokens are added to the model cache one at a time, as can be seen on line 332: `logits = model.forward(next_token[:, None], torch.LongTensor([cur_pos]).to(next_token))`.

I will note that the cache's default size is 4096, which also seems to give us the sliding attention window the paper refers to. All 32 transformer block layers of Mistral seem to compute attention scores using only the cache whenever a single token is passed to the `model.forward()` call; otherwise, the attention scores are computed without the cache.

Is my current understanding correct?

Paper link: https://ar5iv.labs.arxiv.org/html/2310.06825#S2.F1

Code where scores are generated directly from cache: https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py#L125-L140
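For reference, the 131K figure is consistent with multiplying the window size by the number of layers: 4096 × 32 = 131,072. Below is a minimal sketch of that reasoning, not code from the repository. The function names `earliest_visible` and `theoretical_span` are hypothetical, and the window convention (position i attends to the `window` most recent positions, itself included) is an assumption, so the toy simulation reaches back `window - 1` positions per layer while the paper rounds this to `window`.

```python
def earliest_visible(seq_len: int, window: int, n_layers: int) -> list[int]:
    """For each position, return the earliest input token that can
    influence its hidden state after n_layers of causal sliding-window
    attention (hypothetical helper, for illustration only)."""
    # Before any attention layer, position i depends only on token i.
    earliest = list(range(seq_len))
    for _ in range(n_layers):
        # Assumed convention: position i attends to positions
        # max(0, i - window + 1) .. i of the previous layer.
        # `earliest` is non-decreasing, so the minimum over that window
        # sits at its left edge.
        earliest = [earliest[max(0, i - window + 1)] for i in range(seq_len)]
    return earliest


def theoretical_span(window: int, n_layers: int) -> int:
    """Approximate span as window * layers, the rounding used in the paper."""
    return window * n_layers


# Toy case: window of 4, 3 layers, 40 tokens.
toy = earliest_visible(seq_len=40, window=4, n_layers=3)
print(toy[39])  # earliest token that can still influence the last position

# Values from the post: window 4096, 32 layers.
print(theoretical_span(4096, 32))
```

In the toy case each layer extends the reach of the last position by `window - 1` tokens, so no single layer ever looks further back than the window, yet the stacked layers do. This only shows what is reachable in principle; it says nothing about how much information actually survives the trip through the layers.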