turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.
MIT License

RuntimeError: temp_state buffer is too small #244

Closed · daniel-kukiela closed 1 year ago

daniel-kukiela commented 1 year ago

Hi,

I'm probably hitting a bug when running inference with a prompt of 2k tokens or more, or with a batch whose total token count exceeds 2k tokens (for example, a batch of 2 prompts of 1k+ tokens each).
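A rough sketch of the setup that triggers it, following the usual AutoGPTQ pipeline pattern (the model path and prompt are placeholders, and the exact loader arguments are assumptions):

from transformers import AutoTokenizer, TextGenerationPipeline
from auto_gptq import AutoGPTQForCausalLM

# placeholder path to a GPTQ-quantized Llama checkpoint loaded with the exllama kernel
model_path = "path/to/llama-gptq"

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = AutoGPTQForCausalLM.from_quantized(model_path, device="cuda:0", use_safetensors=True)

pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)

# a single prompt longer than 2048 tokens (or a batch whose combined length exceeds 2048)
# is enough to overflow the preallocated temp_state buffer in the exllama kernel
long_prompt = "word " * 2100
output = pipeline(long_prompt, max_new_tokens=200)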

The error:

Traceback (most recent call last):
  File "test_hparams_gptq.py", line 440, in <module>
    generated_text = model.generate(context, max_length=200)
  File "test_hparams_gptq.py", line 286, in generate
    sequences = self.pipeline(
  File "/usr/local/lib/python3.8/dist-packages/transformers/pipelines/text_generation.py", line 204, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/pipelines/base.py", line 1129, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/usr/local/lib/python3.8/dist-packages/transformers/pipelines/base.py", line 1136, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/usr/local/lib/python3.8/dist-packages/transformers/pipelines/base.py", line 1035, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/usr/local/lib/python3.8/dist-packages/transformers/pipelines/text_generation.py", line 265, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/usr/local/lib/python3.8/dist-packages/auto_gptq/modeling/_base.py", line 443, in generate
    return self.model.generate(**kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 1642, in generate
    return self.sample(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py", line 2724, in sample
    outputs = self(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 810, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 698, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 426, in forward
    hidden_states = self.mlp(hidden_states)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py", line 220, in forward
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/auto_gptq/nn_modules/qlinear/qlinear_exllama.py", line 167, in forward
    out = ext_q4_matmul(x.half(), self.q4, self.width)
  File "/usr/local/lib/python3.8/dist-packages/auto_gptq/nn_modules/qlinear/qlinear_exllama.py", line 38, in ext_q4_matmul
    q4_matmul(x, q4, output)
RuntimeError: temp_state buffer is too small

I found a very similar issue in the exllama project (https://github.com/turboderp/exllama/issues/211), and there is a fix in exllama: https://github.com/turboderp/exllama/commit/c16cf49c3f19e887da31d671a713619c8626484e. Is this something that could be fixed in your implementation, too?
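For context, the limit appears to come from how large the exllama scratch buffers are preallocated. On the exllama side the relevant knobs look roughly like this (attribute names taken from ExLlamaConfig and possibly version-dependent, so treat this as a sketch rather than the exact change in that commit):

from model import ExLlamaConfig  # exllama's own module layout

config = ExLlamaConfig("path/to/config.json")  # placeholder path to the model's config.json
config.max_input_len = 4096                    # default is 2048; temp_state is sized from this
config.max_attention_size = 4096 ** 2          # attention scratch buffer grows with it
# ...then build ExLlama(config) and the generator as usual

In the traceback above, though, the buffer is allocated by AutoGPTQ's exllama kernel integration (qlinear_exllama), which is why the question is whether the same fix applies there.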

daniel-kukiela commented 1 year ago

Heh, I put it in the wrong repository, sorry :)