segeljakt opened 1 month ago
Oh, I managed to get it working by turning off kv-caching and then turning it on again:
impl Chat {
    // ...
    pub fn run(&mut self, prompt: &str) -> Result<String> {
        // Append the new prompt's tokens to the running token history.
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );
        self.cache.use_kv_cache = false; // <---- Here: disable the kv-cache before the first forward of this run
        for _ in 0..SAMPLE_LEN {
            let tokens_slice = &self.tokens[self.index..];
            let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
                .squeeze(0)?;
            self.cache.use_kv_cache = true; // <---- Here: re-enable it for the remaining single-token steps
            let logits = candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY,
                &self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
            )?;
            // ...
        }
        // ...
    }
    // ...
}
This means the first forward of every run is done without kv-caching. Is this the correct way to approach it?
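For context, here is a minimal sketch of the Chat struct the method above assumes. The field names are exactly the ones used in the snippet; the concrete types are my assumption, following the llama example:

use candle_core::Device;
use candle_transformers::models::llama::{Cache, Llama};
use tokenizers::Tokenizer;

struct Chat {
    model: Llama,         // loaded Llama model from candle-transformers
    tokenizer: Tokenizer, // HF tokenizer used to encode each prompt
    cache: Cache,         // kv-cache that owns the use_kv_cache flag
    device: Device,       // device the input tensors are created on
    tokens: Vec<u32>,     // token history across all prompts and replies
    index: usize,         // how many of those tokens the model has consumed
}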
I tried to modify the code in https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs into a chatbot where each new prompt takes the history of all previous prompts into account. This is my code:
When I run it, I get this error:
The error happens the second time I call Chat::run in main, and it is thrown from this statement. The first time I run the chat in main, the shape of input is [1,5]. After producing an output token, the next shape of input is [1,1], since I use key-value caching. When I later enter a new prompt and run the chat, the input shape is [1,3] (which includes the EOS token from the previous run). The error disappears if I drop some tokens so that the shape becomes [1,1]. Is there something that says the shape must be [1,1] when we use key-value caching?
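My guess, from reading the candle-transformers llama code, is that the causal mask is built with shape (seq_len, seq_len) and then broadcast over attention scores of shape (batch, heads, seq_len, cached_len + seq_len). Once the cache already holds keys, those shapes only line up when seq_len is 1, which would explain why a [1,3] input fails while [1,1] works. This is my reading, not a confirmed diagnosis.

If that is the cause, an alternative to toggling use_kv_cache would be to pre-fill the cache with the not-yet-consumed tokens one at a time, so every forward sees a [1,1] input. A rough, untested sketch against the same field names as above:

// Feed the pending prompt tokens one by one so the input shape is
// always [1, 1] and use_kv_cache can stay enabled throughout.
while self.index < self.tokens.len() {
    let input = Tensor::new(&self.tokens[self.index..self.index + 1], &self.device)?
        .unsqueeze(0)?; // shape [1, 1]
    // `self.index` is the absolute position of this token, matching the
    // `index_pos` argument the llama example passes to `forward`.
    let logits = self.model.forward(&input, self.index, &mut self.cache)?;
    self.index += 1;
    // only sample from `logits` once the last prompt token has been fed
}

With that, the decode loop always runs with a warm cache and [1,1] inputs, and the kv-cache flag never needs to be switched off.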