Closed ispobock closed 4 months ago
Hi @lzhangzz @lvhan028, do we have any plans to support token attention in TurboMind in the near future? Thanks.
Will this proposal conflict with turbomind's stateful inference?
> Do we have any plans to support token attention in TurboMind in the near future?

No. It makes no sense to me.
> No. It makes no sense to me.

OK.
> Will this proposal conflict with turbomind's stateful inference?

I think the issues here are the same as those mentioned in https://github.com/InternLM/lmdeploy/pull/1393#issuecomment-2041458974; perhaps we can consider only non-interactive dialogue scenarios for now.
> Will this proposal conflict with turbomind's stateful inference?

No conflict, I think. In stateful inference the history cache is prioritized. A block's hash covers the whole prefix up to that block (including the history tokens). But if the sequence is very long, the hash computation becomes time-consuming.
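If the hash must cover the whole prefix, one way to keep it cheap for long sequences is incremental hashing: derive each block's key from the parent block's key plus the block's own tokens, so each new block hashes only `BLOCK_SIZE` tokens instead of the whole prefix. This is only an illustrative sketch; the `chained_key` and `prefix_keys` helpers are hypothetical, not part of the proposal:

```python
BLOCK_SIZE = 2

def chained_key(parent_key, block_tokens):
    # the key covers the whole prefix through the parent's key,
    # but only BLOCK_SIZE tokens are hashed per block
    return hash((parent_key, tuple(block_tokens)))

def prefix_keys(tokens):
    keys, parent = [], 0  # 0 acts as the root key
    n = len(tokens) - len(tokens) % BLOCK_SIZE  # drop the partial block
    for i in range(0, n, BLOCK_SIZE):
        parent = chained_key(parent, tokens[i:i + BLOCK_SIZE])
        keys.append(parent)
    return keys

a = prefix_keys([1, 2, 3, 4])
b = prefix_keys([1, 2, 5, 6])
assert a[0] == b[0] and a[1] != b[1]  # shared first block, divergent second
# the same block at a different position gets a different key,
# because the key depends on the entire prefix before it
assert prefix_keys([3, 4])[0] != a[1]
```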
A sequence is composed of blocks, and a block is the smallest unit of reuse, so the smallest unit for cache management and prefix matching should be the block, not the token. We therefore designed a new data structure, Block Trie, adapted from the trie for block matching: where a trie node stores a character, a Block Trie node stores a block, identified by a hash_key used for prefix matching.
- Performance
- Implementation and maintenance complexity
Here is a demo Block Trie implementation:
```python
import time
import heapq

BLOCK_SIZE = 2


class Seq:
    def __init__(self):
        self.tokens = []
        self.blocks = []


class BlockManager:
    def __init__(self):
        self.free_blocks = []

    def get_free_num(self):
        return len(self.free_blocks)

    def allocate(self):
        if len(self.free_blocks) == 0:
            return None
        return self.free_blocks.pop(0)

    def add_to_free(self, free_blocks):
        self.free_blocks.extend(free_blocks)


class Node:
    def __init__(self):
        self.children: dict[int, 'Node'] = {}
        self.hash_key: int = -1  # hash of the block's tokens
        self.block = None        # one block per node
        self.tokens = []         # token ids in the block
        self.last_access_time = -1
        self.ref_count = 0
        self.parent = None       # for eviction


class BlockTrie:
    def __init__(self):
        self.root = Node()
        self.leaves = set([self.root])
        self.block_manager = BlockManager()

    def allocate_blocks(self, seq: Seq):
        matched_blocks = self._match_or_insert(seq)
        # handle the tokens that were not matched or inserted,
        # including the last incomplete block
        i = len(matched_blocks)
        while i * BLOCK_SIZE < len(seq.tokens):
            if self.block_manager.get_free_num() == 0:
                # no free block, need to evict
                evicted = self._evict(1)
                if len(evicted) == 0:
                    print("insert failed: not enough free blocks")
                    break
            block = self.block_manager.allocate()
            assert block is not None
            seq.blocks.append(block)
            i += 1

    # O(L) for one seq
    def _match_or_insert(self, seq: Seq):
        matched_blocks = []
        curr = self.root
        i = 0
        while (i + 1) * BLOCK_SIZE <= len(seq.tokens):
            # the last incomplete block is ignored
            curr_tokens = seq.tokens[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
            key = hash(tuple(curr_tokens))
            if key in curr.children:
                child = curr.children[key]
                if child.tokens == curr_tokens:
                    # tokens match: not a hash collision, reuse the block
                    matched_blocks.append(child.block)
                    child.ref_count += 1
                    child.last_access_time = time.time()
                    curr = child
                    i += 1
                else:
                    # hash collision, don't use the cache
                    break
            else:
                # match failed, insert a new node
                node = Node()
                node.hash_key = key
                if self.block_manager.get_free_num() == 0:
                    # no free block, need to evict
                    evicted = self._evict(1)
                    if len(evicted) == 0:
                        print("insert failed: not enough free blocks")
                        break
                block = self.block_manager.allocate()
                assert block is not None
                node.block = block
                node.tokens = curr_tokens
                node.ref_count += 1
                node.last_access_time = time.time()
                node.parent = curr
                matched_blocks.append(node.block)
                self.leaves.add(node)
                if len(curr.children) == 0:
                    self.leaves.remove(curr)
                curr.children[key] = node
                curr = node
                i += 1
        seq.blocks = matched_blocks
        return matched_blocks

    def free_blocks(self, seq: Seq):
        curr = self.root
        i = 0
        while (i + 1) * BLOCK_SIZE <= len(seq.tokens):
            curr_tokens = seq.tokens[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
            key = hash(tuple(curr_tokens))
            if key in curr.children:
                child = curr.children[key]
                if child.tokens == curr_tokens:
                    child.last_access_time = time.time()
                    child.ref_count -= 1
                    i += 1
                    curr = child
                else:
                    break
            else:
                break
        # blocks not in the cache are freed immediately
        self.block_manager.add_to_free(seq.blocks[i:])

    def _evict(self, need_num):
        evicted_blocks = []
        # TODO: use an ordered set to optimize
        # the root holds no block and must never be evicted
        leaves = [x for x in self.leaves
                  if x.ref_count == 0 and x.block is not None]
        # id() breaks timestamp ties so heapq never compares Node objects
        leaves = [(leaf.last_access_time, id(leaf), leaf) for leaf in leaves]
        heapq.heapify(leaves)
        for _ in range(need_num):
            if len(leaves) == 0:
                print("not enough blocks to evict")
                break
            _, _, leaf = heapq.heappop(leaves)
            evicted_blocks.append(leaf.block)
            del leaf.parent.children[leaf.hash_key]
            self.leaves.remove(leaf)
            if len(leaf.parent.children) == 0:
                self.leaves.add(leaf.parent)
                if leaf.parent.ref_count == 0 and leaf.parent.block is not None:
                    heapq.heappush(leaves, (leaf.parent.last_access_time,
                                            id(leaf.parent), leaf.parent))
        self.block_manager.add_to_free(evicted_blocks)
        return evicted_blocks


if __name__ == "__main__":
    bt = BlockTrie()
    # 5 blocks in total, each holding BLOCK_SIZE tokens
    bt.block_manager.add_to_free([1, 2, 3, 4, 5])

    seq1 = Seq()
    seq1.tokens = [1, 2, 3, 4, 5, 6, 7]
    bt.allocate_blocks(seq1)
    print(seq1.blocks)  # [1, 2, 3, 4]

    seq2 = Seq()
    seq2.tokens = [1, 2, 3, 4]
    bt.allocate_blocks(seq2)
    print(seq2.blocks)  # [1, 2] reused the first 2 blocks of seq1

    seq3 = Seq()
    seq3.tokens = [2, 3]
    bt.allocate_blocks(seq3)
    print(seq3.blocks)  # [5] no reuse

    bt.free_blocks(seq1)
    seq4 = Seq()
    seq4.tokens = [2, 3, 4, 5, 6, 7]
    bt.allocate_blocks(seq4)
    print(seq4.blocks)  # [5, 4, 3] the first block reuses seq3's block;
                        # the last 2 use the 2 blocks freed from seq1

    bt.free_blocks(seq2)
    seq5 = Seq()
    seq5.tokens = [2, 4]
    bt.allocate_blocks(seq5)
    print(seq5.blocks)  # [2] reused a block freed after seq2 was released
```
@lvhan028 @lzhangzz @grimoire @irexyc Do you have any suggestions?
Hi all @lvhan028 @lzhangzz @grimoire @irexyc After internal discussions, we think the Block Trie combines the advantages of the Radix Tree and the Hash Table, and is easier to implement and maintain than a Radix Tree. Currently vLLM and RTP-LLM implement Automatic Prefix Caching, and supporting general cache implementations has a high cost and questionable practicality, so we think the Block Trie is a good balance. We would like to hear the community's opinions and suggestions on this solution. Thanks.
Regarding whether to support the prefix cache for TurboMind's interactive interface, that can be discussed later. Since integration involves LlamaBatch, we would also like a general idea of when the community plans to refactor LlamaBatch, so that we can better coordinate module development and the final merge will not be too painful.
Sounds good to me.
For vllm, we tested prefix caching and found almost a 20% performance improvement on the ShareGPT dataset with manually added system prompts. So this feature may bring great benefits in scenarios with system prompts.
vllm without prefix caching:

```
Successful requests:              1000
Benchmark duration (s):           147.58
Total input tokens:               336509
Total generated tokens:           160646
Request throughput (req/s):       6.78
Input token throughput (tok/s):   2280.18
Output token throughput (tok/s):  1088.54
```

vllm with prefix caching:

```
Successful requests:              1000
Benchmark duration (s):           124.07
Total input tokens:               336509
Total generated tokens:           159917
Request throughput (req/s):       8.06
Input token throughput (tok/s):   2712.29
Output token throughput (tok/s):  1288.95
```
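As a sanity check, the relative gains implied by the two runs can be computed directly from the figures above:

```python
# Relative improvement from the two benchmark runs above.
baseline_req_tput, cached_req_tput = 6.78, 8.06
baseline_dur, cached_dur = 147.58, 124.07

req_gain = (cached_req_tput - baseline_req_tput) / baseline_req_tput
dur_drop = (baseline_dur - cached_dur) / baseline_dur

print(f"request throughput: +{req_gain:.1%}")  # +18.9%, i.e. "almost 20%"
print(f"benchmark duration: -{dur_drop:.1%}")  # -15.9%
```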
Reproduce procedure:

1. Add the system prompt to each request:

```python
SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
"""

prompts = [SYSTEM_PROMPT + prompt for prompt, _ in dataset]
```

2. Start vllm with or without `--enable-prefix-caching`:

```
python -m vllm.entrypoints.openai.api_server --model /workdir/llm_models/llama2_13b_chat --trust-remote-code
python -m vllm.entrypoints.openai.api_server --model /workdir/llm_models/llama2_13b_chat --trust-remote-code --enable-prefix-caching
```

3. Profile the performance:

```
python benchmark_serving.py --backend vllm --host 0.0.0.0 --port 8000 --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model /workdir/llm_models/llama2_13b_chat --tokenizer /workdir/llm_models/llama2_13b_chat/
```
Hi all @lvhan028 @lzhangzz @grimoire @irexyc @AllentDan A 20% increase in throughput is significant, especially in scenarios with system-prompt templates, and even more so with longer system prompts. We have validated this in our internal scenarios and observed an almost 50% increase in throughput, so the benefits of this feature are substantial. Perhaps the community can prioritize this feature and give some suggestions on the technical solution we proposed. Thanks.
Turbomind's stateful inference is built on top of block-level caching with an LRU eviction policy, so there will be no conflict with prefix caching.

The caching mechanism is implemented by SequenceManager and BlockManager together. Currently blocks can only be reused within the same session. To add prefix caching, we may add an extra layer between SequenceManager and BlockManager that explicitly manages a global pool of prefixes, so that prefixes from other sessions can be recognized and reused.

The data structure that manages the prefixes (radix tree, hash table, or simply a sorted vector) is best decoupled from sequence- or block-level objects: it deals with virtual "prefix blocks of tokens" rather than real memory blocks (those are BlockManager's job).
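As a rough sketch of such a layer, a global pool could map a hash of the full token prefix to a cached block id, letting any session query it before asking BlockManager for fresh blocks. All names here (`PrefixPool`, `match`, `register`) are hypothetical; TurboMind's real SequenceManager/BlockManager interfaces differ:

```python
class PrefixPool:
    """Hypothetical layer between SequenceManager and BlockManager.

    Tracks virtual "prefix blocks of tokens" (prefix hash -> block id);
    the real memory stays the BlockManager's job. Hashing the full prefix
    for every block is the naive O(L^2) version; a chained hash would fix it.
    """
    def __init__(self, block_size=2):
        self.block_size = block_size
        self._pool = {}  # prefix hash -> block id

    def _keys(self, tokens):
        n = len(tokens) - len(tokens) % self.block_size  # drop partial block
        return [hash(tuple(tokens[:i + self.block_size]))
                for i in range(0, n, self.block_size)]

    def match(self, tokens):
        """Block ids for the longest cached prefix of `tokens`."""
        blocks = []
        for key in self._keys(tokens):
            if key not in self._pool:
                break
            blocks.append(self._pool[key])
        return blocks

    def register(self, tokens, block_ids):
        """Publish a sequence's full blocks for cross-session reuse."""
        for key, block_id in zip(self._keys(tokens), block_ids):
            self._pool[key] = block_id


pool = PrefixPool()
pool.register([1, 2, 3, 4], ["b0", "b1"])
print(pool.match([1, 2, 3, 4, 9]))  # ['b0', 'b1'] -- prefix reused
print(pool.match([9, 9]))           # [] -- no shared prefix
```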
In the newest Turbomind engine, the smallest block_size is 64, while the prefix (system prompt) is usually 100-200 tokens long. With block-level reuse (like the Block Trie mentioned above), the last partially matched block (at most 63 tokens, almost 30% of the prefix length) is wasted. Maybe we need to figure out a new solution to reuse the last partially matched block.
> Maybe we need to figure out a new solution to reuse the last partially matched block.

No need for that. It's possible to reduce the smallest block size to 16 in the future.
> No need for that. It's possible to reduce the smallest block size to 16 in the future.

@lzhangzz When is this change planned for release?
> No need for that. It's possible to reduce the smallest block size to 16 in the future.

And what's the side effect of reducing the smallest block size? Will it affect performance?
> When is this change planned for release?

Likely in May.
> Will it affect performance?

There may be a slight performance degradation.
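The partial-block waste discussed above can be checked with a quick calculation for prefixes of 100-200 tokens under block sizes 64 and 16 (worst case is `block_size - 1` tokens wasted, e.g. 63 for block size 64):

```python
# Tokens of a shared prefix that land in the last partial block
# cannot be reused under block-level matching.
def wasted(prefix_len, block_size):
    return prefix_len % block_size

for prefix_len in (100, 150, 200):
    for block_size in (64, 16):
        w = wasted(prefix_len, block_size)
        print(f"prefix={prefix_len:3d} block_size={block_size:2d} "
              f"wasted={w:2d} tokens ({w / prefix_len:.0%})")
```

For a 150-token prefix, a block size of 64 wastes 22 tokens (15%), while a block size of 16 wastes only 6 (4%), which is why shrinking the block size largely removes the need for partial-block reuse.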
Motivation
Prefix caching is supported in many projects such as vLLM, SGLang, and RTP-LLM. The Torch engine is going to support this feature in https://github.com/InternLM/lmdeploy/pull/1393, so we raise this issue to discuss the plan for the Turbomind engine.
As discussed in https://github.com/InternLM/lmdeploy/pull/1393, there are mainly two approaches to implementing this feature: a Radix Tree or a Hash Table.
We prefer to implement this feature in Turbomind using the Hash Table approach first. Existing Hash Table implementations in other projects have slightly higher time complexity for prefix matching and insertion than a Radix Tree, but we can implement the feature with a Hash Table first and optimize the performance later if needed.
@lvhan028 @lzhangzz @grimoire @irexyc Do you have any suggestions? As this feature mainly involves engine changes, we may need to coordinate its order with the LlamaBatch refactor.