marella / ctransformers

Python bindings for the Transformer models implemented in C/C++ using the GGML library.
MIT License

GPTQ models are not respecting context_length or max_seq_len settings #191

Open chrsbats opened 7 months ago

chrsbats commented 7 months ago

No matter what I try, I can't set the context_length of a GPTQ model. It's overridden by ExLlama, which then sets the cache size and context_length to whatever its default is (in this case 2048).

The first problem is that it's actually using max_seq_len to set the context_length, and the Config dataclass doesn't include that field. Even if I monkey-patch the Config dataclass and set the value everywhere I can find:

        model = "TheBloke/NeuralHermes-2.5-Mistral-7B-GPTQ"
        config = AutoConfig.from_pretrained(model)
        config.max_seq_len = 8000
        config.context_length = 8000
        config.config.max_seq_len = 8000
        config.config.context_length = 8000
        self.config = config
        self.llm = AutoModelForCausalLM.from_pretrained(model,config=self.config,local_files_only=True,max_seq_len=8000)
        self.llm.config.context_length = 8000
        self.llm.config.max_seq_len = 8000

None of these will change the context_length used by the GPTQ model because it uses the ExLlama config instead.
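For context, exllama sizes its key/value cache from config.max_seq_len when the model and cache objects are constructed, so the field has to be set before they exist. A minimal sketch of loading through exllama directly, assuming the standalone exllama API (the paths are placeholders):

        from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig

        exl_config = ExLlamaConfig("model/config.json")    # placeholder path
        exl_config.model_path = "model/model.safetensors"  # placeholder path
        exl_config.max_seq_len = 8000  # must be set BEFORE the model and cache are built
        exl_model = ExLlama(exl_config)
        cache = ExLlamaCache(exl_model)  # K/V buffers are sized from config.max_seq_len here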

If I reach in and modify the ExLlama config after loading the model via

        self.llm._model.config.max_seq_len = 8000
        self.llm._model.config.max_input_len = 8000

It correctly sets the context_length then, but the cache has already been allocated at size 2048, so it promptly crashes whenever you ask for a long response:

File ~/anaconda3/envs/orac/lib/python3.11/site-packages/exllama/model.py:369, in ExLlamaAttention.fused(self, hidden_states, cache, buffer, input_layernorm, lora)
    365 query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.config.head_dim)
    367 # Get k, v with past
--> 369 key_states = cache.key_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
    370 value_states = cache.value_states[self.index].narrow(2, 0, past_len + q_len).narrow(0, 0, bsz)
    372 # Repeat K/V heads if n_kv_heads < n_heads

RuntimeError: start (0) + length (2049) exceeds dimension size (2048).
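
The error itself is easy to reproduce in isolation: torch.Tensor.narrow cannot slice past the allocated dimension, and the cache was allocated at 2048. A minimal sketch with hypothetical cache shapes:

        import torch

        # K/V cache laid out as (batch, heads, max_seq_len, head_dim); 2048 is the default
        key_states = torch.zeros(1, 8, 2048, 128)

        key_states.narrow(2, 0, 2048)  # fine: exactly the allocated length
        key_states.narrow(2, 0, 2049)  # RuntimeError: start (0) + length (2049) exceeds dimension size (2048)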