rustformers / llm

[Unmaintained, see README] An ecosystem of Rust libraries for working with large language models
https://docs.rs/llm/latest/llm/
Apache License 2.0

Proper Rewind+Refeed when stop token is detected. #407

Open JuliaMerz opened 10 months ago

JuliaMerz commented 10 months ago

See this segment of code:

/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
/// This callback is used in [InferenceSession::infer] in chat_mode.
pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
    stop_sequence: &'a str,
    mut callback: impl FnMut(String) + 'a,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
    let mut stop_sequence_buf = String::new();
    move |resp| match resp {
        InferenceResponse::InferredToken(token) => {
            // We've generated a token, so we need to check if it's contained in the stop sequence.
            let mut buf = stop_sequence_buf.clone();
            buf.push_str(&token);

            if buf.starts_with(stop_sequence) {
                // We've generated the stop sequence, so we're done.
                // Note that this will contain the extra tokens that were generated after the stop sequence,
                // which may affect generation. This is non-ideal, but it's the best we can do without
                // modifying the model.
                stop_sequence_buf.clear();
                return Ok(InferenceFeedback::Halt);
            } else if stop_sequence.starts_with(&buf) {
                // We've generated a prefix of the stop sequence, so we need to keep buffering.
                stop_sequence_buf = buf;
                return Ok(InferenceFeedback::Continue);
            }

            // We've generated a token that isn't part of the stop sequence, so we can
            // pass it to the callback.
            stop_sequence_buf.clear();
            callback(buf);
            Ok(InferenceFeedback::Continue)
        }
        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
        _ => Ok(InferenceFeedback::Continue),
    }
}

At the moment we don't catch every possible case of stop tokens; we only catch some of them. Ideally we'd catch all cases, then rewind the session in order to bring the model into the state right after the stop token.
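
To make one missed case concrete, here is a minimal, self-contained sketch (it does not use the llm crate; the stop sequence and token stream are invented) that replays the buffering logic above against a stream where the stop sequence starts in the middle of a token:

fn main() {
    let stop_sequence = "\nUser:";
    // Hypothetical token stream: the model emits "!" fused with the start of
    // the stop sequence in a single token, as BPE tokenizers often do.
    let tokens = ["Hello", "!\nUs", "er:", " hi"];

    let mut buf = String::new();
    let mut halted = false;
    for token in tokens {
        buf.push_str(token);
        if buf.starts_with(stop_sequence) {
            halted = true;
            break;
        } else if stop_sequence.starts_with(&buf) {
            // A prefix of the stop sequence: keep buffering.
            continue;
        }
        // Neither branch matched, so the buffer is flushed and the "\nUs"
        // carried by the second token can never start a prefix match again.
        print!("{buf}");
        buf.clear();
    }
    // Never halts: "!\nUs" neither starts with "\nUser:" nor is a prefix of it.
    assert!(!halted);
}

The flush also loses ground after a false prefix match: if the buffer holds "\nUs" and the next token is "\nUser:", the whole buffer is flushed at once and the genuine stop sequence inside it is missed.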

LLukas22 commented 10 months ago

Maybe we could change the callback to work with the actual tokens instead of the decoded string; that should make detecting the correct stop sequence simpler. Or is there a specific reason we need to work with strings here?
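
As a sketch of what that might look like (everything here is invented for illustration, not the crate's API: `TokenId`, `StopDetector`, and the pre-tokenized ids are placeholders), id-level matching reduces to comparing a sliding window of the most recent ids against the tokenized stop sequence:

use std::collections::VecDeque;

type TokenId = u32;

struct StopDetector {
    stop_ids: Vec<TokenId>,
    // The last `stop_ids.len()` generated ids.
    tail: VecDeque<TokenId>,
}

impl StopDetector {
    fn new(stop_ids: Vec<TokenId>) -> Self {
        Self { tail: VecDeque::with_capacity(stop_ids.len()), stop_ids }
    }

    /// Feed one generated id; returns true once the most recent ids
    /// exactly match the stop sequence.
    fn push(&mut self, id: TokenId) -> bool {
        if self.tail.len() == self.stop_ids.len() {
            self.tail.pop_front();
        }
        self.tail.push_back(id);
        self.tail.iter().eq(self.stop_ids.iter())
    }
}

fn main() {
    // Suppose "\nUser:" tokenizes to [13, 2188, 25] in some vocabulary.
    let mut detector = StopDetector::new(vec![13, 2188, 25]);
    let generated = [15496, 0, 13, 2188, 25];
    assert!(generated.iter().any(|&id| detector.push(id)));
}

A nice side effect: on a match, the rewind length is known exactly, namely `stop_ids.len()` trailing tokens.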

Which cases of stop tokens aren't detected by the current implementation? Could you provide an example? Rewinding the session should be simple, except for cases where there is some token merging happening in the tokenizer 🤔
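
On the token-merging caveat: if the tokenizer fuses the stop text with surrounding characters (e.g. "!\nUs" as a single token), the stop sequence never appears as an exact id subsequence, so a string-level check on the decoded tail is still needed. Here is a merging-tolerant sketch, with `decode` and its toy vocabulary standing in for the real detokenizer, which also reports how many trailing tokens to rewind:

// `decode` stands in for the real tokenizer's detokenization; the ids and
// their texts are invented for illustration.
fn decode(ids: &[u32]) -> String {
    ids.iter()
        .map(|id| match id {
            0 => "!\nUs", // a merged token covering "!" plus "\nUs"
            1 => "er:",
            2 => "Hello",
            _ => "",
        })
        .collect()
}

/// Returns how many trailing tokens to rewind if the decoded tail of
/// `generated` ends with `stop_sequence`, or None if it doesn't.
fn stop_match(generated: &[u32], stop_sequence: &str) -> Option<usize> {
    for n in 1..=generated.len() {
        let tail = decode(&generated[generated.len() - n..]);
        if tail.ends_with(stop_sequence) {
            return Some(n);
        }
        // Once the decoded tail is at least as long as the stop sequence and
        // still doesn't end with it, prepending more tokens can't change the
        // suffix, so we can stop early.
        if tail.len() >= stop_sequence.len() {
            return None;
        }
    }
    None
}

fn main() {
    // "Hello" + "!\nUs" + "er:" decodes to "Hello!\nUser:".
    assert_eq!(stop_match(&[2, 0, 1], "\nUser:"), Some(2));
}

Dropping those n trailing tokens rewinds past the stop sequence, but the merged first token is exactly the awkward case: its "!" belongs to the reply and would have to be re-fed after truncation, hence the rewind+refeed in the title.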