Open · twoletters opened this issue 10 months ago
Maybe like this?
```python
def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
    if n_past < 0:
        n_past = self.n_tokens
    if im_start is not None:  # [<|im_start|>, name, nl]
        lps = compute_lps_array(im_start)
        _idx = kmp_search(self.input_ids, im_start, n_keep + n_discard, n_past, lps)
        if _idx >= n_keep:  # actually >= n_keep + n_discard, given the search start
            n_discard = _idx - n_keep  # trim up to the nearest im_start boundary
        else:
            _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
            if _idx >= n_keep:
                n_keep = _idx + len(im_start)  # keep at least one im_start structure
    self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
    self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
    self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
    self.n_tokens = n_past - n_discard

def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
    if self._n_ctx < self.n_tokens + len(tokens):
        tmp_n_discard = max(n_discard, self.n_tokens + len(tokens) - self._n_ctx)
        # pass im_start by keyword so it is not mistaken for the n_past argument
        self.kv_cache_seq_ltrim(n_keep, tmp_n_discard, im_start=im_start)
    for i in range(0, len(tokens), self.n_batch):
        pass  # batch evaluation elided here, as in the original comment
```
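
The sketch above relies on `compute_lps_array` and `kmp_search`, which aren't shown in the comment. A minimal version of those Knuth-Morris-Pratt helpers, assuming `input_ids` and `im_start` are flat sequences of token ids, could look like this:

```python
def compute_lps_array(pattern):
    # Longest proper prefix that is also a suffix, for each prefix of pattern.
    lps = [0] * len(pattern)
    length, i = 0, 1
    while i < len(pattern):
        if pattern[i] == pattern[length]:
            length += 1
            lps[i] = length
            i += 1
        elif length:
            length = lps[length - 1]
        else:
            lps[i] = 0
            i += 1
    return lps

def kmp_search(seq, pattern, start, end, lps):
    # Index of the first occurrence of pattern in seq[start:end], or -1.
    i, j = start, 0
    while i < end:
        if seq[i] == pattern[j]:
            i += 1
            j += 1
            if j == len(pattern):
                return i - j
        elif j:
            j = lps[j - 1]
        else:
            i += 1
    return -1
```

Returning -1 on a miss keeps the `_idx >= n_keep` checks above well-defined when no `im_start` marker is found in the searched window.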
A recent paper from Meta/MIT/CMU proposed StreamingLLM, a simple yet efficient technique for enabling "infinite" context. Better yet, the implementation in llama.cpp is as trivial as changing the `n_keep` value via the `--keep` option, as discussed in this issue. Unfortunately, the high-level API of llama-cpp-python does not support the `keep`/`n_keep` parameter.

It should be simple to add the parameter to the high-level API, ideally in the constructor of class `Llama`, and to pass it along to the function `llama_cpp.llama_load_model_from_file` as part of the parameter `lparams` here.
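
As a rough illustration of the proposal, and not of the library's actual API, a wrapper that accepts `n_keep` in its constructor might look like the sketch below. The `StreamingLlama` name and the idea of storing `n_keep` on the instance are assumptions made for illustration:

```python
from llama_cpp import Llama

class StreamingLlama(Llama):
    """Hypothetical wrapper that accepts n_keep in the constructor.

    Keeping n_keep on the instance and consulting it when trimming the KV
    cache is one possible way to plumb the parameter through; it is not the
    current llama-cpp-python API.
    """

    def __init__(self, model_path, n_keep=4, **kwargs):
        super().__init__(model_path=model_path, **kwargs)
        self.n_keep = n_keep  # tokens at the head of the context never evicted

# Usage sketch: n_ctx is a real Llama constructor parameter; n_keep is the
# proposed addition.
llm = StreamingLlama("models/7B/model.gguf", n_keep=4, n_ctx=2048)
```

A trimming method such as the `eval_t` sketch above could then default its `n_keep` argument to `self.n_keep`.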