turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

[REQUEST] Passing cache to and from generate() function for use in loop #678

Closed: cmunna0052 closed this issue 6 days ago

cmunna0052 commented 6 days ago

Problem

I am very new to exllamav2, so my apologies if this can already be achieved through other means, but the goal is to produce a sequence of generations that progressively fill in a template. I supply an initial prompt x0, generate a chunk of tokens x1, append it to the prompt, append x2 more predefined tokens, then supply x0 + x1 + x2 to the model to generate the next chunk, and so on. This whole process is repeated for a number of examples.

The question is how to ensure that I can properly utilize the kv-cache during this process. In Huggingface, it is straightforward to achieve this with the following loop:

import torch
from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=1, use_cache=True, return_dict_in_generate=True)
for ii in range(len(examples)):
    # Reset the cache for each new example.
    past_key_values = None
    for jj, pos in enumerate(generation_positions):
        input_batch = torch.cat([truncated_input[ii], example_output[:, :pos]], dim=1)
        attention_mask = input_batch != self.tokenizer.pad_token_id
        output_dict = model.generate(
            input_ids=input_batch,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            generation_config=generation_config,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        sequence = output_dict["sequences"]
        # Retrieve the updated cache and feed it back in on the next step.
        past_key_values = output_dict["past_key_values"]

        generated_tokens = sequence[:, -1]
        generated_text = [self.tokenizer.decode(x, skip_special_tokens=True) for x in generated_tokens]
        curr_generated_labels[:, jj] = torch.tensor([int(x) for x in generated_text])
        example_output[:, pos] = generated_tokens

This works very well, since the cache is supplied directly to generate() and then retrieved immediately in output_dict.

I'm not sure how I might achieve something similar in exllama. I am currently trying to set it up in the following way:

import torch
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)

gen_settings = ExLlamaV2Sampler.Settings.greedy()
generator = ExLlamaV2DynamicGenerator(
    model = model, cache = cache, tokenizer = tokenizer, paged = False
)
for ii in range(len(examples)):
    for jj, pos in enumerate(generation_positions):
        input_batch = torch.cat([truncated_input[ii], example_output[:, :pos]], dim=1)
        job = ExLlamaV2DynamicJob(
            input_ids = input_batch,
            gen_settings = gen_settings,
            max_new_tokens = 1,
            identifier = (ii, jj),
        )
        generator.enqueue(job)
        # Drain iterate() until the job finishes; only the final ("eos")
        # result carries "full_completion".
        generated_text = []
        while generator.num_remaining_jobs():
            for result in generator.iterate():
                if result.get("eos"):
                    generated_text.append(result["full_completion"])
        curr_generated_labels[:, jj] = torch.tensor([int(x) for x in generated_text])
        example_output[:, pos] = generated_tokens  # generated_tokens is not produced anywhere above; this is part of what I'm unsure about

However, I have no idea whether the cache is actually doing what I want it to here, or whether there is a straightforward way to update the ExLlamaV2Cache object with the results of each generation and then reset it between examples.

Solution

A simple method to ensure the cache can be passed to and from a generation process.

Alternatives

No response

Explanation

It would make using the cache much easier for large numbers of inferences with specific formats.

Examples

No response

Additional context

No response

Acknowledgements

turboderp commented 6 days ago

The dynamic generator automatically manages the cache and reuses the results of previous jobs, for as many input tokens as it can match between an old and a new job. So if you're building a context bit by bit, all you have to do is make sure you're not editing the past. For instance:
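
(A minimal sketch of the idea, assuming model, cache and tokenizer are already loaded as in the snippets above; the letter strings are just placeholders for whatever the model actually produces, and generate() returns the prompt together with the completion by default.)

from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler

generator = ExLlamaV2DynamicGenerator(model = model, cache = cache, tokenizer = tokenizer)
settings = ExLlamaV2Sampler.Settings.greedy()

# Job 1: prompt "ABCDE"; say it generates "FGHIJ".
# The cache now holds keys/values for ABCDEFGHIJ.
out1 = generator.generate(prompt = "ABCDE", gen_settings = settings, max_new_tokens = 5)

# Job 2: resend the whole context so far ("ABCDEFGHIJ"). The matching prefix
# is found in the cache and reused; say it generates "KLMNOP".
out2 = generator.generate(prompt = out1, gen_settings = settings, max_new_tokens = 6)

# Job 3: a prompt that shares the ABCDEFGHI prefix but then diverges ("123").
# The shared prefix is still reused; say it generates "456".
out3 = generator.generate(prompt = "ABCDEFGHI123", gen_settings = settings, max_new_tokens = 3)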

At this point the cache would be able to resume generation from any portion of ABCDEFGHIJKLMNOP (starting with A) or ABCDEFGHI123456. You could also launch multiple jobs at once starting with ABCDE and they would all reference the same portion of the cache (deduplication).

So basically, it's all automated. If you're building a context bit by bit, just start each new generation with the entire context-so-far and the generator will only process the bits that don't line up with what's already been processed. When the cache fills up, the oldest (least recently referenced) pages are evicted first.
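
For the template-filling loop in the original post, that could look something like this sketch (an illustration only, not a prescribed pattern: it assumes paged mode so prefix reuse applies, and initial_prompt / template_pieces are hypothetical stand-ins for however the template is actually stored):

from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler

settings = ExLlamaV2Sampler.Settings.greedy()
generator = ExLlamaV2DynamicGenerator(model = model, cache = cache, tokenizer = tokenizer)

for example in examples:
    # Start each example from its own initial prompt (x0). No manual cache
    # reset is needed; pages that stop being referenced are evicted over time.
    context = example.initial_prompt                # hypothetical attribute
    for predefined in example.template_pieces:      # hypothetical attribute
        # Resend the entire context-so-far; only the tokens that don't match
        # what is already cached get prefilled again.
        context = generator.generate(prompt = context, gen_settings = settings, max_new_tokens = 1)
        # Append the next predefined chunk of the template and continue.
        context += predefined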

There's a bit more info here.

cmunna0052 commented 6 days ago

That is super helpful, thanks!