abetlen / llama-cpp-python

Python bindings for llama.cpp
https://llama-cpp-python.readthedocs.io
MIT License

Add n_keep parameter to Llama constructor to enable StreamingLLM #954

Open twoletters opened 10 months ago

twoletters commented 10 months ago

A recent paper by Meta/MIT/CMU proposed StreamingLLM, a simple yet efficient way to enable "infinite" context. Better yet, implementing it in llama.cpp is as trivial as changing the `n_keep` value via the `--keep` option, as discussed in this issue. Unfortunately, the high-level API of llama-cpp-python does not expose the `keep`/`n_keep` parameter.

It should be simple to add the parameter to the high-level API, ideally in the constructor of the `Llama` class, and to pass it along to `llama_cpp.llama_load_model_from_file` as part of the `lparams` argument here.
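For illustration, here is a minimal sketch of the semantics such a parameter would have. The helper name `streaming_trim` is hypothetical and not part of llama-cpp-python; it just mirrors the context-shift behavior behind llama.cpp's `--keep`: once the window is full, keep the first `n_keep` tokens, drop the next `n_discard`, and slide the rest back.

    def streaming_trim(tokens: list, n_ctx: int, n_keep: int, n_discard: int) -> list:
        """Hypothetical helper, not part of llama-cpp-python: mimic llama.cpp's
        --keep context shift on a plain token list by keeping the first n_keep
        tokens and dropping the n_discard tokens that follow them."""
        if len(tokens) <= n_ctx:
            return tokens  # still fits in the context window; nothing to evict
        return tokens[:n_keep] + tokens[n_keep + n_discard:]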

Limour-dev commented 8 months ago

Maybe like this?

    def kv_cache_seq_ltrim(self, n_keep, n_discard=256, n_past=-1, im_start=None):
        # Drop n_discard tokens right after the first n_keep tokens from the KV
        # cache, then shift the remaining tokens back so positions stay dense.
        if n_past < 0:
            n_past = self.n_tokens
        if im_start is not None:  # [<|im_start|>, name, nl]
            # compute_lps_array / kmp_search are standard KMP string-search helpers
            lps = compute_lps_array(im_start)
            _idx = kmp_search(self.input_ids, im_start, n_keep + n_discard, n_past, lps)
            if _idx >= n_keep:  # actually >= n_keep + n_discard
                n_discard = _idx - n_keep  # truncate at the nearest im_start boundary
            else:
                _idx = kmp_search(self.input_ids, im_start, n_keep, n_past, lps)
                if _idx >= n_keep:
                    n_keep = _idx + len(im_start)  # keep at least one im_start structure
        self._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
        self._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_past, -n_discard)
        self.input_ids[n_keep:n_past - n_discard] = self.input_ids[n_keep + n_discard:n_past]
        self.n_tokens = n_past - n_discard

    def eval_t(self, tokens, n_keep=4, n_discard=256, im_start=None):
        if self._n_ctx < self.n_tokens + len(tokens):
            tmp_n_discard = max(n_discard, self.n_tokens + len(tokens) - self._n_ctx)
            # pass im_start by keyword so it is not mistaken for n_past
            self.kv_cache_seq_ltrim(n_keep, tmp_n_discard, im_start=im_start)
        for i in range(0, len(tokens), self.n_batch):
            # evaluate each batch as in Llama.eval (omitted here)
            pass
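The snippet above calls `compute_lps_array` and `kmp_search` without defining them. A standard Knuth-Morris-Pratt implementation along the following lines would fit; this is a sketch of the assumed helpers, not code from the repository:

    def compute_lps_array(pattern):
        """KMP failure function: for each prefix of pattern, the length of the
        longest proper prefix that is also a suffix."""
        lps = [0] * len(pattern)
        length, i = 0, 1
        while i < len(pattern):
            if pattern[i] == pattern[length]:
                length += 1
                lps[i] = length
                i += 1
            elif length != 0:
                length = lps[length - 1]
            else:
                i += 1
        return lps

    def kmp_search(tokens, pattern, start, end, lps):
        """Return the index of the first occurrence of pattern in
        tokens[start:end], or -1 if it does not occur."""
        i, j = start, 0
        while i < end:
            if tokens[i] == pattern[j]:
                i += 1
                j += 1
                if j == len(pattern):
                    return i - j
            elif j != 0:
                j = lps[j - 1]
            else:
                i += 1
        return -1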