vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature] Support for SelfExtend-style context expansion #2349

Open · creatorrr opened 6 months ago

creatorrr commented 6 months ago

In the paper *LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning*, the authors describe a method to extend the context window of any RoPE-based model at inference time, without any fine-tuning. I haven't gotten around to testing it myself, but the results reported in the paper look game-changing.
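
If I'm reading it right, tokens within a neighbor window `w_n` keep their exact relative positions, while tokens beyond it share floor-divided group positions (group size `G`), so a model pretrained on length `L` should cover roughly `(L - w_n) * G + w_n` tokens; e.g. a 4096-token model with `G = 4` and `w_n = 512` reaches `(4096 - 512) * 4 + 512 = 14848`.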

How could we add support for this in vLLM? Here's their algorithm, as pseudocode from the paper:

```python
q, k, v                    # queries, keys, and values
seq_len, pos               # input sequence length, position indices
g_size, w_size = G, w_n    # group size G, neighbor-window size w_n

# normal self-attention over exact positions
ngb_q = apply_pos_emcode(q, pos)
ngb_k = apply_pos_emcode(k, pos)
ngb_attn = matmul(ngb_q, ngb_k)
ngb_attn = causal_mask(ngb_attn)

# grouped self-attention over floor-divided positions
g_pos = pos // g_size              # the floor operation
shift = w_size - w_size // g_size  # aligns the two schemes at the window edge
s_g_pos = g_pos + shift
g_q = apply_pos_emcode(q, s_g_pos)
g_k = apply_pos_emcode(k, g_pos)
g_attn = matmul(g_q, g_k)
g_attn = causal_mask(g_attn)

# mask is 1 exactly where query and key are within w_size of each other
g_mask = tril(ones([seq_len - w_size, seq_len - w_size]))
mask = ones([seq_len, seq_len])
mask[w_size:, :-w_size] -= g_mask

# merge by replacement: neighbor attention inside the window,
# grouped attention outside it
attn = where(mask, ngb_attn, g_attn)

attn_weights = softmax(attn)
output = matmul(attn_weights, v)
```

alphanlp commented 6 months ago

Does this code run as-is?

creatorrr commented 5 months ago

@alphanlp no, that's just pseudocode of their algorithm.
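
That said, here's a minimal single-head PyTorch sketch of the merge step, in case anyone wants to poke at it. The helper names, shapes, and the scaling factor are mine, and the RoPE here is the textbook rotation rather than vLLM's kernel, so treat it as a sketch of the idea, not the implementation:

```python
import torch

def rope(x, pos, base=10000.0):
    # textbook rotary embedding, applied at arbitrary integer positions `pos`
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    ang = pos[:, None].float() * inv_freq[None, :]   # [seq_len, d/2]
    sin, cos = ang.sin(), ang.cos()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def self_extend_attn(q, k, v, g_size=4, w_size=8):
    seq_len, d = q.shape
    pos = torch.arange(seq_len)

    # neighbor attention: exact positions
    ngb = rope(q, pos) @ rope(k, pos).T / d ** 0.5

    # grouped attention: floor-divided positions, with queries shifted so the
    # two schemes agree at the window boundary
    g_pos = pos // g_size
    shift = w_size - w_size // g_size
    grp = rope(q, g_pos + shift) @ rope(k, g_pos).T / d ** 0.5

    # merge by replacement: neighbor scores within w_size, grouped beyond
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    within_window = (pos[:, None] - pos[None, :]) < w_size
    scores = torch.where(within_window, ngb, grp)
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(16, 32) for _ in range(3))
out = self_extend_attn(q, k, v)   # [16, 32]
```

The `within_window` test matches the paper's mask construction (`mask[w_size:, :-w_size] -= g_mask` keeps neighbor attention exactly where `query_pos - key_pos < w_size`), but this ignores batching, heads, and the KV cache, which is where the real vLLM integration work would be.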

Ki6an commented 5 months ago

+1

1787648106 commented 4 months ago

As the paper mentions, SelfExtend does not support flash-attention.

Mooler0410 commented 3 months ago

> As the paper mentions, SelfExtend does not support flash-attention.

We recently added FlashAttention support for SelfExtend.

darwinlc commented 3 months ago

+1, hoping to see SelfExtend added to vLLM.

creatorrr commented 3 months ago

@zhuohan123 thoughts? Pointers on how I can contribute would be awesome :)