vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

Attention sliding window #3385

Open caiom opened 6 months ago

caiom commented 6 months ago

In the Hugging Face "eager" Mistral implementation, a sliding window of size 2048 lets each token attend to 2049 tokens (itself plus the previous 2048). The same is true for flash attention. In the current vLLM implementation, a window of 2048 only attends to 2048 tokens:

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

attn_bias = BlockDiagonalCausalMask.from_seqlens([4096])
attn_bias = attn_bias.make_local_attention(2048)
mask = attn_bias._create_block_mask([4096, 4096])
# Count the attended (unmasked) key positions for each query row.
print(torch.sum(mask == 0, dim=1))

Output: tensor([ 1, 2, 3, ..., 2048, 2048, 2048])

The output should be: tensor([ 1, 2, 3, ..., 2049, 2049, 2049])
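A minimal check of the expected behaviour (an untested sketch, assuming make_local_attention counts the current token inside the window, so passing sliding_window + 1 reproduces the Hugging Face counts):

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Same check as above, but with the window widened by one token.
attn_bias = BlockDiagonalCausalMask.from_seqlens([4096])
attn_bias = attn_bias.make_local_attention(2048 + 1)
mask = attn_bias._create_block_mask([4096, 4096])
print(torch.sum(mask == 0, dim=1))  # expected: tensor([ 1, 2, 3, ..., 2049, 2049, 2049])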

Context: https://github.com/huggingface/transformers/issues/29623

cadedaniel commented 6 months ago

Good catch, this should be fixed. PR contributions are welcome!

caiom commented 6 months ago

@cadedaniel One issue is that the paged attention kernel is not compatible with many models, including the original Mistral, if I'm not mistaken. The kernel won't accept a window size of 4096 (4097 tokens).

What I suggest is to fix this issue but still accept models like Mistral, emitting a warning when the window has to be adjusted.

Something like:

# X is the alignment required by the paged attention kernel.
sliding_window = sliding_window + 1
if sliding_window % X != 0:
    sliding_window = (sliding_window // X) * X
    print(f'[WARNING]: Window size not compatible with Paged Attention, rounding down to {sliding_window}')
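For concreteness, with Mistral (sliding_window = 4096) and a hypothetical alignment of X = 16 (the real value is whatever the kernel enforces), the adjustment goes 4096 -> 4097 -> 4096:

sliding_window, X = 4096, 16                 # X = 16 is made up for illustration
sliding_window = sliding_window + 1          # 4097 tokens attended, matching HF / flash attention
if sliding_window % X != 0:
    sliding_window = (sliding_window // X) * X
print(sliding_window)                        # 4096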

If you are OK with that I can create the PR.

cadedaniel commented 6 months ago

Can you help me understand the impact of this divergence on the quality of the model?

caiom commented 6 months ago

I think it is minimal. Until 5 days ago, HF used a different convention depending on the backend: Flash Attention would attend to 4097 tokens while eager would attend to 4096 tokens, like vLLM. We could print the warning only if the difference is larger than one?

cadedaniel commented 6 months ago

I feel that, since information flow over a 4k-token window is unlikely to change much from a single extra token, this is OK to leave without a warning. But if I'm wrong, then we should add a warning / fix.

caiom commented 4 months ago

vLLM's SDPA backend is also affected:

import torch
from vllm.attention.backends.torch_sdpa import _make_sliding_window_bias

mask = _make_sliding_window_bias([4096], 2048, torch.float32)
# Count the attended (unmasked) key positions for each query row.
print(torch.sum(mask[0][0, ...] == 0, dim=1))

Output: tensor([ 1, 2, 3, ..., 2048, 2048, 2048])
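Assuming the SDPA helper follows the same convention, passing sliding_window + 1 should reproduce the Hugging Face counts (an untested sketch):

import torch
from vllm.attention.backends.torch_sdpa import _make_sliding_window_bias

# Same check as above, with the window widened by one token.
mask = _make_sliding_window_bias([4096], 2048 + 1, torch.float32)
print(torch.sum(mask[0][0, ...] == 0, dim=1))  # expected: tensor([ 1, 2, 3, ..., 2049, 2049, 2049])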

caiom commented 4 months ago

transformers 4.40.1 for reference:

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=2048)
mask = attn_mask_converter.to_causal_4d(1, 4096, 4096, torch.float32)
print(torch.sum(mask[0, 0, ...] == 0, dim=1))

Output: tensor([ 1, 2, 3, ..., 2049, 2049, 2049])
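The convention itself fits in a few lines of plain PyTorch, independent of any backend: with window w, token i attends to positions max(0, i - w) through i, i.e., at most w + 1 tokens (a reference sketch):

import torch

seq_len, w = 4096, 2048
i = torch.arange(seq_len).unsqueeze(1)  # query positions, shape (seq_len, 1)
j = torch.arange(seq_len).unsqueeze(0)  # key positions, shape (1, seq_len)
allowed = (j <= i) & (j >= i - w)       # causal + sliding window
print(allowed.sum(dim=1))               # tensor([ 1, 2, 3, ..., 2049, 2049, 2049])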

davidgxue commented 1 month ago

Hey, wanted to check: is this problem fixed? Neither my Mistral nor my Phi-3 model can use FA2:

INFO 07-26 19:19:35 selector.py:170] Cannot use FlashAttention-2 backend due to sliding window.
INFO 07-26 19:19:35 selector.py:54] Using XFormers backend.

and I'm using the latest version of vLLM.