turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

cache.clone() is not creating a copy of the cache #226

Closed hidoba closed 7 months ago

hidoba commented 7 months ago

cache.clone() seems to create a new cache that loses the state of the original cache.

The following example produces rubbish.

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

import torch
import random

model_directory = "model/llm/"
config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()
model = ExLlamaV2(config)
print("Loading model: " + model_directory)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)

settings_proto = ExLlamaV2Sampler.Settings()
settings_proto.temperature = 0.8
settings_proto.top_p = 0.75

prompt_base = "Paris is a capital of"

ids = tokenizer.encode(prompt_base)
cache = ExLlamaV2Cache(model, max_seq_len = 16000)
model.forward(ids, cache, preprocess_only = True)   # prefill: run the prompt through the model to populate the cache

cache2 = cache.clone()                               # the clone should preserve the prefilled state

while True:
    # generate one token at a time, continuing from the cloned cache
    logits = model.forward(ids[:,-1:], cache2, input_mask = None).float().cpu()
    r = random.random()
    token, _, _ = ExLlamaV2Sampler.sample(logits, settings_proto, ids, r, tokenizer)
    ids = torch.cat([ids, token], dim = 1)
    if token.item() == tokenizer.newline_token_id or cache.current_seq_len == cache.max_seq_len:
        break

output = tokenizer.decode(ids)[0]
print(output)

Output:

Loading model: model/llm/
Paris is a capital of #16548.

If you replace the line cache2 = cache.clone() with cache2 = cache, the script correctly continues the prompt: Paris is a capital of France ...

turboderp commented 7 months ago

Thank you for the bug report and especially for the code to easily reproduce it. :star: The cache was being cloned, but it also tracks the current sequence length, and that wasn't being copied across to the clone. Fixed in the latest commit.
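For readers who hit this before updating: the gist of the fix described above is that a clone has to copy both the cached key/value tensors and the tracked position. Below is a minimal sketch of that idea, not the actual commit; the key_states / value_states attribute names and the per-layer tensor layout are assumptions for illustration.

from exllamav2 import ExLlamaV2Cache

def clone_cache(cache, model):
    # Allocate a fresh cache of the same size
    new_cache = ExLlamaV2Cache(model, max_seq_len = cache.max_seq_len)

    # Copy the cached keys and values for every layer
    # (key_states / value_states are assumed attribute names, not necessarily the library's API)
    for dst, src in zip(new_cache.key_states, cache.key_states):
        dst.copy_(src)
    for dst, src in zip(new_cache.value_states, cache.value_states):
        dst.copy_(src)

    # The step that was missing: carry over how far the cache has been filled,
    # otherwise the clone behaves like an empty cache starting at position 0
    new_cache.current_seq_len = cache.current_seq_len
    return new_cache

With that last assignment in place, generating from the clone continues the original prompt instead of producing garbage, which matches the behavior of the fixed clone() in the latest commit.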