turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs

Refactor token healing initialization. #330

Open bjj opened 4 months ago

bjj commented 4 months ago

begin_stream now leaves the stream in a state where the first call to stream will generate the healed token without needing a special case.

This is just the first step toward making it possible to feed the streaming generator logits (e.g. from batch processing) rather than having it pull logits from the model itself. The heal_token case made the number of model.forward calls per stream call variable; this change makes it constant.
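A minimal sketch of the call pattern this is aiming for (begin_stream / stream are the generator's real entry points; the token_healing flag and the (chunk, eos, tokens) return shape are assumptions about the exact API and may differ by version):

```python
def run_healed_stream(generator, input_ids, settings):
    # token_healing flag and return shape assumed, see note above
    generator.begin_stream(input_ids, settings, token_healing = True)
    while True:
        # After this patch: exactly one model.forward per call, including
        # the first call, which now emits the healed token like any other.
        chunk, eos, _ = generator.stream()
        print(chunk, end = "")
        if eos:
            break
```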

bjj commented 4 months ago

Added a second patch with .append_logits()
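Hypothetical usage (append_logits is what the patch adds; the exact signature and its interaction with stream() are my reading of it, not settled API):

```python
def step_from_external_logits(generator, logits_row):
    # logits_row: logits for this sequence from an external (e.g. batched)
    # model.forward, rather than a forward run by the generator itself.
    generator.append_logits(logits_row)   # method added by this patch
    return generator.stream()  # decode step without the generator's own forward
```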

turboderp commented 4 months ago

I don't understand why you'd want to send logits to the generator? The whole point of the generator is to pull logits from the model until a stop condition is met. If you just want to sample from logits you've produced with model.forward(), why not call the sampler directly?
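i.e. something like this (sketch only; the argument order and return arity of sample() vary between versions, so treat both as approximate):

```python
import random
from exllamav2.generator import ExLlamaV2Sampler

def sample_once(logits, sequence_ids, tokenizer):
    # Approximate signature, see note above.
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.8
    settings.top_p = 0.9
    token, _, _ = ExLlamaV2Sampler.sample(
        logits, settings, sequence_ids, random.random(), tokenizer)
    return token
```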

Also, the rationale for not making the token healing a separate iteration of stream() is to make sure that every call to stream() uses exactly one token of the available context length.

bjj commented 4 months ago

> If you just want to sample from logits you've produced with model.forward(), why not call the sampler directly?

The sampler isn't really the problem. It's all the token-decode logic in the streaming generator that I want to re-use. I did start out calling sample/decode directly, and then I started running into all the issues with that (sentencepiece, partial UTF-8 strings, multi-token stop strings, etc.).
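To make the pitfall concrete, here is the flavor of logic involved (illustrative only, not the generator's actual code):

```python
def stream_decode(tokenizer, token_iter):
    # Decode the *whole* sequence every step: sentencepiece-style tokenizers
    # attach leading spaces to tokens, so decoding tokens individually and
    # concatenating gives different text than decoding the full sequence.
    ids, emitted = [], ""
    for tok in token_iter:
        ids.append(tok)
        text = tokenizer.decode(ids)
        safe = len(text)
        # Hold back a trailing replacement character: it usually means a
        # multi-byte UTF-8 character is split across tokens and incomplete.
        if text.endswith("\ufffd"):
            safe -= 1
        # (A real implementation also withholds any suffix that could be
        # the start of a multi-token stop string.)
        yield text[len(emitted):safe]
        emitted = text[:safe]
```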

> The whole point of the generator is to pull logits from the model until a stop condition is met.

But "pull logits from the model" is about one line out of several hundred (not counting speculative generation). I wanted to refactor the whole class so that the stream-decode part could be re-used. The best place to start seemed to be moving toward a model where init leaves the state in a "ready for next model.forward" and stream could have a variant like stream(logits, ...)

> Also, the rationale for not making the token healing a separate iteration of stream() is to make sure that every call to stream() uses exactly one token of the available context length.

If that's important, I think it can be restored easily by having an explicit stream(logits, ...) variant (without that guarantee) and invoking _stream() twice on startup in the healing case for regular callers.
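Concretely, something like this (all names here, _setup / _stream included, are invented for illustration, not code from the PR):

```python
class StreamRefactorSketch:

    def begin_stream(self, input_ids, settings, token_healing = False):
        self._setup(input_ids, settings)
        self._pending_heal = token_healing

    def stream(self, logits = None):
        if logits is not None:
            # Explicit logits-fed variant: the caller did the forward pass,
            # and the one-token-per-call guarantee does not apply.
            return self._stream(logits)
        if self._pending_heal:
            # Regular callers: fold the healing step into the first call so
            # every public stream() still consumes exactly one token of
            # context. (A real version would merge the two text chunks.)
            self._pending_heal = False
            self._stream()
        return self._stream()
```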

bjj commented 4 months ago

Do you have any further feedback on this PR? Can you spell out how you envision a server that supports continuous batching and streaming using the facilities in exllamav2?
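To make the question concrete, the rough shape I have in mind (entirely hypothetical; stream(logits) is the variant proposed above, not existing exllamav2 API):

```python
def serve_step(model, batch_cache, active, next_tokens):
    # One continuous-batching step: a single batched forward, then fan the
    # per-sequence logits out to each sequence's streaming decoder. All
    # names here are hypothetical.
    batch_logits = model.forward(next_tokens, batch_cache)
    still_active = []
    for i, seq in enumerate(active):
        chunk, eos, _ = seq.generator.stream(batch_logits[i:i + 1])
        seq.send(chunk)
        if not eos:
            still_active.append(seq)
        # else: the slot is freed and a queued request can join the batch
    return still_active
```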