Flipbookee opened this issue 1 month ago
Every time we backtrack, we sample an alternate token using the adjusted probs that we stored, then continue inference by calling model.generate again. So the KV cache would (I assume) be reset from that point.
This isn't exactly optimised, but it's functional at least, and it uses the cache from the point model.generate is called up until it hits a stopping criterion.
I'm not sure if the KV cache can do checkpointing / rewinding so that we can optimise this part better. I haven't dug this deep into transformers inference before, so if you have any ideas on how to do this, please let me know.
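For concreteness, the resampling step could look something like this. A minimal sketch only: `stored_probs`, `backtrack_pos`, and `rejected_token_id` are placeholder names for whatever the implementation actually stores, and zeroing out the rejected token is an assumption about how the alternate is chosen:

```python
import torch

# `stored_probs[pos]` is assumed to hold the adjusted (post-softmax)
# distribution saved when position `pos` was originally sampled.
probs = stored_probs[backtrack_pos].clone()    # shape: (vocab_size,)
probs[rejected_token_id] = 0.0                 # assumption: never resample the rejected token
probs = probs / probs.sum()                    # renormalise after zeroing
alt_token_id = torch.multinomial(probs, num_samples=1).item()
```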
> I'm not sure if the KV cache can do checkpointing / rewinding so that we can optimise this part better.
If the KV cache is of type DynamicCache, you can crop it, yes. I'm doing this here. Note that I'm passing a negative number (as in "remove the latest X tokens") rather than the new length the cache should have (I tried that too and it ended in an error; my assumption is that it's related to the BOS token, but I haven't checked).
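If I follow, the whole backtrack-and-resume flow could then look roughly like this. A sketch only: the model name is a placeholder, `DynamicCache.crop` with a negative argument is the call described above, and `alt_token_id` would come from the stored adjusted probs as in the earlier sketch:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tok = AutoTokenizer.from_pretrained("gpt2")    # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Once upon a time", return_tensors="pt")
cache = DynamicCache()

# First pass: the cache object is filled in place during generation.
out = model.generate(
    **inputs,
    past_key_values=cache,
    max_new_tokens=32,
    do_sample=True,
    return_dict_in_generate=True,
)

# Backtrack n tokens: crop the cache and the sequence together.
n = 5
cache.crop(-n)                    # negative => "remove the latest n" entries
prefix = out.sequences[:, :-n]    # token ids the cropped cache still covers

# Append the alternate token and resume; only the uncached tail is
# re-processed, so generation continues from the backtrack point.
resume_ids = torch.cat([prefix, torch.tensor([[alt_token_id]])], dim=-1)
out2 = model.generate(
    resume_ids,
    past_key_values=cache,
    max_new_tokens=32,
    do_sample=True,
    return_dict_in_generate=True,
)
```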
Thanks for the tip, Mihai. I'll give this a try.
That's a great tip, Mihaiii! I didn't know it could be done like that.
There's another potential issue related to backtracking: the positional encodings for subsequent tokens will be slightly off after backtracking by more than one token. I don't know how much that might affect the output, but I'd guess it depends on the model and its positional encoding method, possibly on the number of backtracked tokens, and on the type of output being generated (structured, as in code, versus unstructured text).
Adjusting the positions would be a single line of code, I'd imagine, but I don't know if there's an easy way to do that.
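For what it's worth, in a manual forward pass the positions can be overridden explicitly (a sketch, assuming a model whose forward accepts position_ids; `new_ids` is a placeholder for the post-backtrack input). With generate(), cropping the cache back to the backtrack point should already make the positions, which are derived from the cache length, line up:

```python
import torch

# `start` is the absolute position of the first token in `new_ids`,
# taken from how many tokens the (cropped) cache already covers.
start = cache.get_seq_length()
position_ids = torch.arange(start, start + new_ids.shape[1]).unsqueeze(0)
outputs = model(
    input_ids=new_ids,
    past_key_values=cache,
    position_ids=position_ids,
    use_cache=True,
)
```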
If I understood the backtracking process correctly, it only relies on restoring the (adjusted) pre-softmax logits for the first token in the slop sequence, which works for generating an alternative token at that position. But how does backtracking work for the next token after that if we don't restore the KV cache as well? For sequences longer than one token, the KV cache would have advanced already.