dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

Continuous generation in Outlines #667

Open rlouf opened 7 months ago

rlouf commented 7 months ago

I am opening this issue to roughly sketch the next big milestone for Outlines, tentatively called "continuous generation". There are still many rough edges and open questions.

The first goal is to allow sampling of sequences like these:

from outlines import generate, models

model = models.transformers("mistralai/Mistral-7B-v0.1")
generator = generate.text(model)

sequence = "What are the most popular types of vehicles?\n"
for i in range(6):
    sequence += f"{i}, "
    sequence += generator(sequence, stop_at=["\n"])
    sequence += "\n"

By "sampling these sequences" I mean being able to run, for instance, beam search that optimizes the sequence as a whole rather than each generation call separately.
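To make "optimize the sequence as a whole" concrete, here is a minimal, framework-free sketch (not Outlines code). It assumes a hypothetical toy model over string tokens, treats forced text segments like the `f"{i}, "` in the loop above as prompt (appended but not scored), and runs a beam search whose beams survive across `stop_at` boundaries, so the final ranking uses the joint log-probability of all generated parts.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

Tokens = Tuple[str, ...]
Beam = Tuple[Tokens, float]  # (tokens so far, cumulative log-probability)
# Hypothetical model interface: maps a prefix to next-token log-probabilities.
Model = Callable[[Tokens], Dict[str, float]]


def extend_generated(beams: List[Beam], model: Model, beam_width: int,
                     stop: str, max_tokens: int) -> List[Beam]:
    """Grow all beams until each emits `stop`, pruning to `beam_width`."""
    finished: List[Beam] = []
    for _ in range(max_tokens):
        candidates = [
            (tokens + (tok,), score + lp)
            for tokens, score in beams
            for tok, lp in model(tokens).items()
        ]
        candidates.sort(key=lambda b: b[1], reverse=True)
        beams = []
        for cand in candidates[:beam_width]:
            (finished if cand[0][-1] == stop else beams).append(cand)
        if not beams:
            break
    finished.extend(beams)  # keep beams that hit `max_tokens` unfinished
    finished.sort(key=lambda b: b[1], reverse=True)
    return finished[:beam_width]


def sample(model: Model, segments: List[Optional[Tokens]],
           beam_width: int, max_tokens: int = 10) -> Beam:
    """Forced segments are appended unscored; None means: generate until a newline."""
    beams: List[Beam] = [((), 0.0)]
    for segment in segments:
        if segment is None:
            beams = extend_generated(beams, model, beam_width, "\n", max_tokens)
        else:
            beams = [(tokens + segment, score) for tokens, score in beams]
    return beams[0]


def toy_model(prefix: Tokens) -> Dict[str, float]:
    """Hypothetical model where the best second answer depends on the first."""
    last = prefix[-1]
    if last == "0":  # first generated field
        return {"a": math.log(0.6), "b": math.log(0.4)}
    if last == "1":  # second generated field
        if "a" in prefix:
            return {"c": math.log(0.5), "d": math.log(0.5)}
        return {"c": 0.0}  # log(1.0)
    return {"\n": 0.0}  # after any content token, stop with probability 1


segments = [("0",), None, ("1",), None]
best_joint = sample(toy_model, segments, beam_width=2)
best_greedy = sample(toy_model, segments, beam_width=1)
```

With `beam_width=1` the search commits to the locally likelier first answer (`a`, probability 0.6) and ends with joint probability 0.3; carrying two beams across the segment boundary finds `b` then `c`, with joint probability 0.4. Whether forced tokens should also contribute to the score is a design choice this sketch sidesteps.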

All we have to do is return a Sequence object instead of a string, with the following attributes and methods:

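from typing import Tuple

import torch

from outlines.models.tokenizer import Tokenizer


class Sequence:
    token_ids: torch.Tensor
    weights: torch.Tensor
    kv_cache: Tuple
    tokenizer: Tokenizer

    def __str__(self):
        # Decode the stored token ids back into text
        return self.tokenizer.decode(self.token_ids)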

Sequence should feel like a string: besides printing it, we should be able to slice it, add it to a string or to another sequence, and carry on generating:

class Sequence:
    ...

    def __getitem__(self, key):
        if isinstance(key, int):
            # Just return the character? There's not much more we can do here.
            raise NotImplementedError
        if isinstance(key, slice):
            # Different behavior depending on whether `start` is 0. If `start = 0` we can
            # keep part of the KV Cache. Otherwise we need to re-compute the KV
            # Cache, i.e. consider the `Sequence` as a new prompt.
            #
            # We will likely need to split tokens. For instance, if we call `sequence[:10]`
            # and index 10 is the letter `m` in `formida`, we can encode `afor` and append
            # it to the previous token ids. Edge cases should automatically be handled when
            # aligning prompt and generation.
            raise NotImplementedError
        raise TypeError(f"Invalid index type: {type(key).__name__}")

    def __add__(self, other):
        if isinstance(other, str):
            # Signal that KV cache + logprob need to be re-computed
            raise NotImplementedError
        if isinstance(other, Sequence):
            # Concatenate token_ids
            # Concatenate logprobs
            # Signal that KV Cache after `other` needs to be recomputed
            raise NotImplementedError
        return NotImplemented
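The behavior described in these comments can be prototyped without a model. The sketch below is not the Outlines implementation: `ToyTokenizer` (a greedy longest-match tokenizer over a fixed vocabulary), `ToySequence`, the integer `cache_len` (how many leading tokens still have a valid KV cache) and `None` log-probs (meaning "must be recomputed") are all hypothetical stand-ins chosen for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


class ToyTokenizer:
    """Greedy longest-match tokenizer over a fixed vocabulary (illustrative only)."""

    def __init__(self, pieces: List[str]):
        self.pieces = pieces

    def encode(self, text: str) -> List[int]:
        ids = []
        while text:
            # Raises ValueError if no piece matches; the toy vocabulary must cover the text.
            best = max((i for i, p in enumerate(self.pieces) if text.startswith(p)),
                       key=lambda i: len(self.pieces[i]))
            ids.append(best)
            text = text[len(self.pieces[best]):]
        return ids

    def decode(self, ids: List[int]) -> str:
        return "".join(self.pieces[i] for i in ids)


@dataclass
class ToySequence:
    token_ids: List[int]
    logprobs: List[Optional[float]]  # None = must be recomputed by the model
    tokenizer: ToyTokenizer
    cache_len: int = 0  # leading tokens whose KV cache entries are still valid

    def __str__(self) -> str:
        return self.tokenizer.decode(self.token_ids)

    def _new_prompt(self, text: str) -> "ToySequence":
        ids = self.tokenizer.encode(text)
        return ToySequence(ids, [None] * len(ids), self.tokenizer, 0)

    def __getitem__(self, key):
        if isinstance(key, int):
            return str(self)[key]  # a single character
        if not isinstance(key, slice):
            raise TypeError(f"Invalid index type: {type(key).__name__}")
        text = str(self)[key]
        if key.start not in (None, 0) or key.step not in (None, 1):
            return self._new_prompt(text)  # not a prefix: no reusable cache
        # Prefix slice: keep every whole token that fits, re-encode the remainder.
        kept, length = 0, 0
        for tid in self.token_ids:
            if length + len(self.tokenizer.pieces[tid]) > len(text):
                break
            kept += 1
            length += len(self.tokenizer.pieces[tid])
        extra = self.tokenizer.encode(text[length:])
        return ToySequence(self.token_ids[:kept] + extra,
                           self.logprobs[:kept] + [None] * len(extra),
                           self.tokenizer, min(self.cache_len, kept))

    def __add__(self, other):
        if isinstance(other, str):
            other = self._new_prompt(other)  # boundary re-tokenization ignored here
        if not isinstance(other, ToySequence):
            return NotImplemented
        # Concatenate log-probs as the issue suggests, but `other`'s KV cache was
        # computed in a different context, so only our own prefix stays cached.
        return ToySequence(self.token_ids + other.token_ids,
                           self.logprobs + other.logprobs,
                           self.tokenizer, self.cache_len)


tok = ToyTokenizer(["The", " form", "idable", " cat", " ", "f", "o", "r", "m", "!"])
seq = ToySequence(tok.encode("The formidable cat"), [-0.1, -0.2, -0.3, -0.4], tok, cache_len=4)
head = seq[:7]  # "The for": keeps "The", re-encodes " for" as " ", "f", "o", "r"
longer = seq + "!"
```

Cutting inside ` form` keeps only the whole tokens before the cut, so `head` retains one cached token and marks the four re-encoded ones as needing new log-probs.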

This should be enough to bring Outlines to feature parity with other DSLs, without Outlines itself being a DSL.

cpfiffer commented 7 months ago

It may also be interesting to get the joint token likelihood, if available. I'm not super familiar with Outlines, but I'd love to be able to compare Sequences probabilistically.

rlouf commented 7 months ago

We could store that in addition to the sequence weights (which can be, but are not necessarily, the log-probability of the sequence).
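If the weights are log-probabilities, the joint likelihood of a Sequence is the sum of the conditional log-probabilities of its tokens, which requires keeping the logits the model produced at each position. A minimal pure-Python sketch, assuming one raw logits vector per token (the function names are illustrative, not an Outlines API):

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized logits vector."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]


def joint_logprob(step_logits: Sequence[Sequence[float]], token_ids: Sequence[int]) -> float:
    """log p(x_1, ..., x_n) = sum_t log p(x_t | x_<t).

    `step_logits[t]` is assumed to be the model's raw (unprocessed) logits at
    position t, i.e. before any constraint masking or sampling transform.
    """
    return sum(log_softmax(logits)[tok] for logits, tok in zip(step_logits, token_ids))


# Two positions over a two-token vocabulary: p = 0.5, then p = 0.75.
logits = [[0.0, 0.0], [math.log(3.0), 0.0]]
print(math.exp(joint_logprob(logits, [0, 0])))  # ≈ 0.375
```

Tokens appended as plain text have no stored logits, so their contribution to this sum is only known after the model is run on them again.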

jeffreyftang commented 7 months ago

Hi @rlouf, I was directed towards this issue by @lapp0 as a prerequisite issue for #657. I'm interested in contributing, but would like to get a sense of the scope of work involved so that I don't make promises I can't keep.

miftahmoha commented 6 months ago

I'm also interested and am working on it right now.

rlouf commented 6 months ago

Great! It is fairly involved: there are many important design decisions to make, and we need to handle recomputing the KV cache after text is concatenated to a previous generation.

Don't hesitate to open a draft PR as soon as possible so I can give some feedback early on.

rlouf commented 6 months ago

"would like to get a sense of the scope of work involved so that I don't make promises I can't keep."

It is fairly involved; interleaving function calls should be easier to implement, though.

lucasavila00 commented 6 months ago

LmScript, a graphical interface for Outlines programs, makes heavy use of continuous generation.

We currently re-send the accumulated prompt for every generation call and handle the chat template on our end.

Better performance for continuous generation would be highly appreciated.
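Re-sending the accumulated prompt does not have to mean recomputing it. A server that keeps the token ids (and KV cache) of the previous call can compare them with the new prompt and only run the model on the part that differs. A hypothetical sketch of that bookkeeping with made-up token ids; it assumes the accumulated prompt tokenizes the same way each time, which is not guaranteed at the boundary between a previous generation and appended text:

```python
from typing import List, Sequence, Tuple


def shared_prefix_len(cached_ids: Sequence[int], new_ids: Sequence[int]) -> int:
    """Number of leading token ids the cached and the new prompt have in common."""
    n = 0
    for a, b in zip(cached_ids, new_ids):
        if a != b:
            break
        n += 1
    return n


def plan_forward_pass(cached_ids: Sequence[int], new_ids: Sequence[int]) -> Tuple[int, List[int]]:
    """Return (cache entries to keep, token ids the model still has to process)."""
    keep = shared_prefix_len(cached_ids, new_ids)
    if keep == len(new_ids) and keep > 0:
        # Assumption: without cached logits we must re-run at least the last
        # token to get the next-token distribution.
        keep -= 1
    return keep, list(new_ids[keep:])


# The client re-sends the accumulated prompt with a new suffix appended.
print(plan_forward_pass([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7]))  # (5, [6, 7])
```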

roberthoenig commented 5 months ago

Super excited for this feature!

One note: it would be great if continuous generation were implemented so that intermediate outputs can be processed and reused during generation:

sequence = "What are the most popular names of vehicles and the length of their names?\n"
for i in range(6):
    sequence += f"{i}, "
    vehicle_name_gen = generator(sequence, stop_at=["\n"])
    name_len = process(len, vehicle_name_gen)  # `process` would be part of the Outlines API and execute the given function during generation
    sequence += vehicle_name_gen + ", " + name_len + " characters long."
    sequence += "\n"
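One way to express this ordering, independent of any `process` helper, is to write the program as a Python generator that yields either literal text or generation requests, and let a driver own the growing sequence; ordinary Python (here `len`) then runs between two requests. This is a hypothetical sketch: `Gen`, `run` and `vehicles` are illustrative names, and `complete` stands in for a model call such as `generator(sequence, stop_at=["\n"])`. A real driver would hold a Sequence and its KV cache rather than a plain `str`.

```python
from typing import Callable, Generator, List, Union


class Gen:
    """Hypothetical request: generate until one of `stop_at` is produced."""

    def __init__(self, stop_at: List[str]):
        self.stop_at = stop_at


Program = Generator[Union[str, Gen], str, None]


def run(program: Program, complete: Callable[[str, List[str]], str]) -> str:
    """Append literal text, fill each `Gen` with `complete`, send results back."""
    sequence = ""
    try:
        step = next(program)
        while True:
            text = complete(sequence, step.stop_at) if isinstance(step, Gen) else step
            sequence += text
            step = program.send(text)  # the program receives generated text as a str
    except StopIteration:
        return sequence


def vehicles(n: int) -> Program:
    yield "What are the most popular names of vehicles and the length of their names?\n"
    for i in range(n):
        yield f"{i}, "
        name = yield Gen(stop_at=["\n"])
        # Plain Python runs here, between two generation requests.
        yield f", {len(name)} characters long.\n"


# Stand-in for a model: returns canned answers instead of generating.
answers = iter(["car", "bicycle"])
print(run(vehicles(2), lambda prompt, stop_at: next(answers)))
```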