turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Reverting/rolling back filter state #384

Open xonfour opened 3 months ago

xonfour commented 3 months ago

Hi there!

I'm using generate.stream_ex() to generate new chunks and I have it returning the top n probabilities. Based on the probability distribution I want to decline/modify tokens or roll the generation back. This is easy without filters, but with filters I cannot find a way to do it. If I understand it correctly, it is currently only implemented "one way" via the feed_filters() call.

1) Is there a way to do this? 2) If not, what would be the best way to implement it? :)
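
(For reference, the kind of loop I mean, just as a sketch; the exact keyword arguments and result dict keys here are assumptions and may not match the current API exactly:)

generator.begin_stream_ex(input_ids, settings, return_top_tokens = 5)

while True:
    res = generator.stream_ex()             # one chunk per call, returned as a dict
    if res["eos"]: break                    # generation finished
    top_tokens = res.get("top_tokens")      # top-K token IDs for this chunk (assumed key)
    top_probs = res.get("top_probs")        # and their probabilities (assumed key)
    # ...decide here whether to keep, modify or revert the chunk...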

turboderp commented 3 months ago

Yeah, the filters are inherently tricky to roll back. The filter implemented in lm-format-enforcer, for instance, has a whole complicated state machine behind it.

There are two approaches I can think of. One would be to restart the filters from an earlier point and then feed in all the tokens up to the point you want to roll back to. Or, depending on what you need exactly, you could skip calling feed_filters if you just want to be able to decline individual tokens.
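
Roughly, that first approach could look like this (a sketch only; rebuild_filters and kept_token_ids are made-up names, and it assumes the filters expose begin() and feed() as in the filter base class):

def rebuild_filters(filters, kept_token_ids):
    # Reset each filter's state machine, then replay the tokens that are being
    # kept so the filters end up in the state they would have reached anyway.
    for f in filters:
        f.begin("")
        for token_id in kept_token_ids:
            f.feed(token_id)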

One issue with stream_ex is that it's not guaranteed to return a token. It streams a chunk of output which is usually a token, but could be less or more in order to guarantee that only whole characters are streamed (in the case of multi-token characters) and that the frontend never has to roll back if, say, a stop condition breaks down to more than one token.

If you're only doing this one token at a time, it should be possible to do multi-turn sampling. I.e. I could add an interface to let you hook in a validator function that's called right after a token is picked by the sampler (under the current filter constraints), which returns an instruction to the generator: accept the token as-is, replace it with another one, or skip feeding it to the filters.

The validator could easily have access to the probabilities and/or raw logits. That would be a pretty clean solution anyway. Do you think it would suffice for what you're doing?

xonfour commented 3 months ago

Yes, that sounds like a great plan. A validation hook (with a list of functions?) could do the trick. That way we won't have to touch the filters at all.

xonfour commented 3 months ago

I could try to implement that. Where should we keep the hooks? In the settings, just like the filters?

turboderp commented 3 months ago

For this purpose I suppose you could introduce a blank token and override the sampler to pick that token based on some condition or other. It would still have to be inserted into the context, but you wouldn't have to feed anything to the filters if it decodes to an empty string.
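
(Checking that a candidate "blank" token really decodes to an empty string might look like this; the token ID is hypothetical and the decode() call is assumed to accept a 1-dim tensor:)

import torch

blank_id = 12345                                    # hypothetical candidate token ID
decoded = tokenizer.decode(torch.tensor([blank_id]))
print(repr(decoded))                                # should print '' if the token is truly blank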

There would be two main issues with it, I think.

In any case, since it would be sort of a "post" filter, adding it to ExLlamaV2Sampler.Settings probably makes the most sense. I'm currently doing some cleanup/refactoring/type hinting, so I'd rather wait till that's done (tomorrow, probably) before adding a significant feature like that.

xonfour commented 3 months ago

I think for now I will just stick to existing tokens like pad, whitespace or even newline. Since I'm forcing JSON, whitespace or newline might even be the best choice: to a non-finetuned model it looks like the kind of formatting it has already learned, so it might feel more natural. But that is speculation...
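
Looking up the replacement token IDs would then be something like this (a sketch; it assumes encode() returns a (1, n) tensor of IDs and ignores sentencepiece quirks around leading spaces):

space_id   = tokenizer.encode(" ")[0, -1].item()    # last ID of the encoded space
newline_id = tokenizer.encode("\n")[0, -1].item()   # last ID of the encoded newline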

Haha good to know you're refactoring, I just wanted to start ;) Give me a sign!

turboderp commented 3 months ago

Okay, I added something really basic to start with. To use it, define a function like so:

def my_hook(result):
    # Replace any token sampled with probability below 10% with an arbitrary fixed token (ID 69)
    if result.sampled_prob.item() < 0.1:  # or whatever
        result.sampled_token = tokenizer.single_token(69)

sampling_settings.post_sampling_hooks = [my_hook]

The function is called for each token right after sampling and passed the sampler output in this format:

@dataclass
class ExLlamaV2PostSamplingResult:

    sampled_token: torch.Tensor | None = None
    sampled_prob: torch.Tensor | None = None
    candidate_tokens: torch.Tensor | None = None
    candidate_probs: torch.Tensor | None = None
    logits: torch.Tensor | None = None

    feed_filters: bool = True

candidate_tokens and candidate_probs will be the top K tokens and their respective probabilities if you start the stream with return_top_tokens = K.

You can override the sampled_token value, and if you set feed_filters = False the current token won't be considered by filters.
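
For example, a hook that declines very unlikely tokens could look something like this (just a sketch building on the example above; the threshold is arbitrary and pad_token_id is assumed to exist on the tokenizer):

def decline_hook(result):
    # Swap an unlikely token for the pad token and keep it away from the
    # filters so their state machines are not advanced by the substitute.
    if result.sampled_prob.item() < 0.05:
        result.sampled_token = tokenizer.single_token(tokenizer.pad_token_id)
        result.feed_filters = False

sampling_settings.post_sampling_hooks = [decline_hook]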

Resampling is a little trickier, but it could be an option, I guess. Just a little extra effort because of the speculative modes.

xonfour commented 3 months ago

Wow, that's great! Already more than I could have expected. I hope I didn't interrupt your cleanup session too much with my little feature request... ;-)

Thank you very much!

Resampling might be an interesting option, but not necessary for me right now. Maybe next step. I will continue testing and see how far I can get with this.

xonfour commented 3 months ago

Works like a charm and I have to admit it: I'm already working on resampling.

To go back n tokens it should be something like this, right?

generator.sequence_ids = generator.sequence_ids[:, :-n]   # drop the last n tokens from the context
generator.cache.current_seq_len -= n                      # shrink the cache to match
generator.future_logits = None                            # discard speculative / queued state
generator.future_tokens = None
generator.held_text = ""                                  # clear everything the streamer is holding back
generator.held_tokens = generator.no_tokens
generator.held_probs = generator.no_probs
generator.held_ptokens = generator.no_ptokens
generator.held_pprobs = generator.no_pprobs
generator.held_logits = generator.no_logits

turboderp commented 3 months ago

This wouldn't be enough, no. The cache and generator need to be kept in sync, and overall the generator doesn't maintain a history of past states you can easily revert to. Filters are state machines that would need to implement individual rollback functionality, which is outside my control in the case of lm-format-enforcer (unless I start to maintain my own fork.)

The only feasible approach would be to truncate the context and re-evaluate all the filters from the beginning by feeding in however many tokens of the generation-so-far that you wish to keep.
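
Combining the context truncation with a filter replay would look roughly like this (a sketch only; rollback, kept_token_ids and prompt_len are made-up names, and the begin()/feed() calls are the same assumption as above):

def rollback(generator, filters, n, prompt_len):
    # Truncate the context and the cache by n tokens.
    # (The streamer's held_* state would also need clearing, as in the snippet above.)
    generator.sequence_ids = generator.sequence_ids[:, :-n]
    generator.cache.current_seq_len -= n

    # Re-evaluate the filters from scratch over the generated tokens that remain.
    kept_token_ids = generator.sequence_ids[0, prompt_len:].tolist()
    for f in filters:
        f.begin("")
        for token_id in kept_token_ids:
            f.feed(token_id)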

Resampling a single token would just be a matter of calling sample again on the logits to produce one new token, with a few extra steps to update the draft model/n-gram cache.

xonfour commented 3 months ago

Yes, after thinking about it for a while, I have come to the conclusion that this is anything but trivial: filter rollback is not a realistic option, and since every single token has to be passed to the filters, tokens cannot simply be held back in some way. Re-evaluating the filters, on the other hand, is certainly not good for performance.

So _gen_single_token.* is obviously the only place where intervention currently makes sense.

Maybe we could have resample: bool = False in the result class, which, if set to True, repeats the sampling step (including calling the hooks again). Optionally we could also have a max_resample_count in the settings to avoid infinite loops.

I used your current implementation to "hack" it so that it chooses a space token on uncertainty, but only if there has already been a space token. Works well with JSON, but needs much more testing. I might also change the temperature if the uncertainty remains too high, but that will probably clash with the dynamic temperature feature.
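
Roughly, the hack looks something like this (a sketch only; space_token_id, the threshold and the seen_space tracking are illustrative, not the actual code):

seen_space = False                                      # tracked outside the hook

def uncertainty_hook(result):
    global seen_space
    if result.sampled_token.item() == space_token_id:   # space_token_id looked up beforehand
        seen_space = True
    elif seen_space and result.sampled_prob.item() < 0.2:
        # Uncertain token and a space has already occurred: emit a space instead
        result.sampled_token = tokenizer.single_token(space_token_id)
        result.feed_filters = False

sampling_settings.post_sampling_hooks = [uncertainty_hook]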

PS: Thinking this further, it might only be possible to make a decision every n tokens whether to continue or revert those n tokens. In addition, the filter would also have to be taken into account (e.g. in JSON, any number of spaces can appear outside of elements, which is rather unfavorable inside them). Maybe one could conjure up something with batching, some kind of "tree sampling" that reuses batch caches and re-evaluates the filters in the background, but that is far out of scope for now (and probably doesn't make sense at all), I know! ;-)