turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

ExLlamaV2Cache_8bit does not work with multiple_caches.py example #215

Closed: lopuhin closed this issue 7 months ago

lopuhin commented 7 months ago

Thanks for a great library and for providing the multiple_caches.py example; it's really helpful for building an in-flight batching HTTP server on top of it.

I tried replacing ExLlamaV2Cache with ExLlamaV2Cache_8bit in multiple_caches.py to save memory, without making any other changes, and this resulted in an error (the change itself is sketched right after the traceback):

Traceback (most recent call last):
  File "/home/user/path/multiple_caches.py", line 136, in <module>
    logits = model.forward(inputs, caches, input_mask = None).float().cpu()
  File "/home/user/path/exllamav2/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/path/exllamav2/venv/lib/python3.10/site-packages/exllamav2/model.py", line 534, in forward
    result, last_state = self._forward(input_ids = input_ids,
  File "/home/user/path/exllamav2/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/path/exllamav2/venv/lib/python3.10/site-packages/exllamav2/model.py", line 655, in _forward
    x = module.forward(x, cache = cache, attn_mask = attn_mask, past_len = past_len, loras = loras)
  File "/home/user/path/exllamav2/venv/lib/python3.10/site-packages/exllamav2/attn.py", line 427, in forward
    batch_keys, batch_values = cache[i].get_kv_state(self.layer_idx, batch_size, 0, past_len)
  File "/home/user/path/exllamav2/venv/lib/python3.10/site-packages/exllamav2/cache.py", line 186, in get_kv_state
    if width > 0: ext_c.fp8_to_fp16(self.key_states[layer_idx], temp_key_state, batch_size, offset, width)
TypeError: '>' not supported between instances of 'tuple' and 'int'
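
To be concrete, the swap is the whole change; the constructor arguments below are illustrative rather than copied from the example, and `model` is assumed to be an already loaded ExLlamaV2 model:

```python
# Sketch of the swap made in multiple_caches.py (illustrative arguments, not
# the exact line from the example). ExLlamaV2Cache_8bit is intended as a
# drop-in replacement for ExLlamaV2Cache that stores keys/values in FP8
# instead of FP16.
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_8bit

# before: one FP16 cache per active sequence
# cache = ExLlamaV2Cache(model, max_seq_len = 256)

# after: same call shape, 8-bit storage
cache = ExLlamaV2Cache_8bit(model, max_seq_len = 256)
caches.append(cache)

# the per-sequence caches are then passed to the model as a list,
# exactly as in the example:
logits = model.forward(inputs, caches, input_mask = None).float().cpu()
```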

I'm using exllamav2 installed from the pre-built wheel at https://github.com/turboderp/exllamav2/releases/tag/v0.0.10, with Python 3.10 and CUDA 12.1 on Linux.

I tried a few naive fixes, but they either crashed or produced incorrect generations.

Do you think more changes might be required in multiple_caches.py in order to use the 8-bit cache?

turboderp commented 7 months ago

Nope, and sorry for the delay. This was a bug in the attention function; it's fixed with the latest commit.
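
For anyone reading the traceback above: the failure boils down to the batched multi-cache path handing the FP8 cache a tuple where an integer width was expected, so the `if width > 0:` check raised. A minimal illustration of that mechanism (not the library's actual code):

```python
# Minimal illustration of the TypeError from the traceback; this is not
# exllamav2 code, just the failing comparison in isolation.
width = (0, 128)  # hypothetical tuple standing in for what arrived as "width"
if width > 0:     # TypeError: '>' not supported between instances of 'tuple' and 'int'
    pass
```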

lopuhin commented 7 months ago

Great, thanks for the fix, it works 👍 I can confirm it allows fitting more cache into memory, at a slight runtime performance penalty.
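
For a rough sense of the memory saving: the 8-bit cache stores each key/value element in one byte instead of two, so the cache footprint is roughly halved. A back-of-the-envelope sketch with illustrative Llama-2-7B-like dimensions (not measured numbers from my setup):

```python
# Rough KV-cache sizing; the model dimensions below are illustrative assumptions.
num_layers, num_kv_heads, head_dim, max_seq_len = 32, 32, 128, 4096

def kv_cache_bytes(bytes_per_element: int) -> int:
    # keys + values, one element per layer, head, head dimension and position
    return 2 * num_layers * num_kv_heads * head_dim * max_seq_len * bytes_per_element

print(f"FP16 cache: {kv_cache_bytes(2) / 2**30:.1f} GiB")  # ~2.0 GiB
print(f"FP8 cache:  {kv_cache_bytes(1) / 2**30:.1f} GiB")  # ~1.0 GiB
```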