Closed: sanchit-gandhi closed this issue 1 month ago.
Note @fxmarty that this issue also occurs when we only pass the input_ids to the model (and not the attention mask).
@sanchit-gandhi Interesting, is this greedy search? With LLaMA greedy search the input_ids stride is always the same; it might be safer to call contiguous/clone anyway.
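To make the stride concern concrete, here is a minimal torch sketch (not from the original thread): slicing a tensor produces a view whose stride differs from a freshly allocated tensor of the same shape, and shape-plus-stride changes are one of the things that can trigger `torch.compile` recompilation; `.contiguous()` (or `.clone()`) normalizes the layout.

```python
import torch

# A contiguous batch of token ids, then a sliced view of it.
ids = torch.arange(24).reshape(4, 6)
view = ids[:, 1:]  # strided view: stride (6, 1) for shape (4, 5)

assert not view.is_contiguous()

# .contiguous() copies into fresh storage with the canonical stride (5, 1),
# so downstream compiled code always sees the same memory layout.
fixed = view.contiguous()
assert fixed.is_contiguous()
assert torch.equal(view, fixed)  # same values, different layout
```

Calling `.clone()` on a contiguous source has the same effect of decoupling from the original storage.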
It's sampling (we set do_sample=True, temperature=1). Having played around with your PR, it looks like the same issue that affects LLaMA also affects Gemma-2, so I've pushed the changes for Gemma/Gemma-2 directly to your PR!
System Info

transformers version: 4.44.0.dev0

Who can help?

@sanchit-gandhi @gante @ArthurZucker

Information

Tasks

- examples folder (such as GLUE/SQuAD, ...)

Reproduction
Print Output:

=> we get only two recompilations (expected), but the inference speed of the second and third runs is significantly lower than that of the first. This pattern happens only after calling past_key_values.reset()
, which suggests a bug in how we're resetting the HybridCache.

Expected behavior
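Runs after warm-up should keep the first run's speed even across reset() calls: for static-shaped caches, reset() is expected to clear the cache tensors in place, without reallocating them, so that compiled graphs keep addressing the same storage. A minimal sketch of that in-place contract, using a hypothetical MiniCache stand-in rather than the real HybridCache:

```python
import torch

class MiniCache:
    """Hypothetical stand-in for a static-shaped KV cache (illustration only)."""

    def __init__(self, num_layers=2, max_len=8, dim=4):
        self.key_cache = [torch.zeros(1, max_len, dim) for _ in range(num_layers)]

    def reset(self):
        # Zero in place: the tensors keep their identity and storage,
        # which compiled graphs rely on between generation runs.
        for k in self.key_cache:
            k.zero_()

cache = MiniCache()
before = [id(k) for k in cache.key_cache]
cache.key_cache[0][0, 0] = 1.0  # simulate a populated cache entry

cache.reset()

after = [id(k) for k in cache.key_cache]
assert before == after  # same tensor objects survived the reset
assert all(bool(k.eq(0).all()) for k in cache.key_cache)  # contents cleared
```

If a reset instead reallocated the buffers (or perturbed their layout), a compiled model would fall off the fast path on the next run, which would match the slowdown described above.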