huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

HybridCache slow after reset #32313

Closed: sanchit-gandhi closed this issue 1 month ago

sanchit-gandhi commented 1 month ago

System Info

Who can help?

@sanchit-gandhi @gante @ArthurZucker

Information

Tasks

Reproduction

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import time
from transformers import AutoTokenizer, Gemma2ForCausalLM
from transformers.cache_utils import HybridCache
import torch

torch.set_float32_matmul_precision("high")
# log graph breaks and re-compilations so we can catch them
torch._logging.set_logs(graph_breaks=True, recompiles=True)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b", attn_implementation="eager")
model.to("cuda")

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_text = "The theory of special relativity states "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
prompt_length = input_ids.input_ids.shape[1]

model.generation_config.min_new_tokens = model.generation_config.max_new_tokens = 32

past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=prompt_length + 4 * model.generation_config.max_new_tokens,
    device=model.device,
    dtype=model.dtype
)

# allow passing an externally managed kv cache to generate
model._supports_cache_class = True
model.generation_config.cache_implementation = None

for i in range(3):
    # two warm-ups
    outputs_1 = model.generate(**input_ids, past_key_values=past_key_values, do_sample=True, temperature=1)
    outputs_2 = model.generate(outputs_1, past_key_values=past_key_values, do_sample=True, temperature=1)

    # one timed run
    torch.cuda.synchronize("cuda")
    start = time.time()
    outputs_3 = model.generate(outputs_2, past_key_values=past_key_values, do_sample=True, temperature=1)
    torch.cuda.synchronize("cuda")
    runtime = time.time() - start
    print(f"Run {i}: {model.generation_config.max_new_tokens / runtime} tok/s")

    past_key_values.reset()

Printed output:

V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /home/sanchit/transformers/src/transformers/models/gemma2/modeling_gemma2.py:891
V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
V0730 05:56:47.345000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 8, actual 1

V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /home/sanchit/transformers/src/transformers/models/gemma2/modeling_gemma2.py:891
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 1, actual 40
V0730 05:57:41.549000 140167293411712 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 8, actual 40

Run 0: 28.8159080589 tok/s
Run 1: 0.878302057247666 tok/s
Run 2: 19.946942197324718 tok/s

=> we get only two recompilations (expected), but the inference speed of the second and third runs is significantly lower than that of the first. This pattern only appears after calling past_key_values.reset(), which suggests a bug in how we reset the HybridCache.
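For reference, the [__recompiles] messages above come from Dynamo guarding on tensor strides: calling the compiled forward with a tensor of the same shape but a different stride fails the guard and triggers a recompilation. A minimal standalone sketch of that mechanism (no transformers involved; the function name step is made up for illustration):

import torch

# print a reason every time Dynamo has to recompile
torch._logging.set_logs(recompiles=True)

@torch.compile(fullgraph=True)
def step(x):
    return x * 2

base = torch.randn(4, 8)

step(base[:, ::2])  # shape (4, 4), strides (8, 2): first compilation
step(base[:, :4])   # shape (4, 4), strides (8, 1): stride guard fails, recompile

That part is expected for the warm-up calls; the surprising part here is the sustained slowdown after the reset, not the recompilations themselves.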

Expected behavior

Run 0: 28.8159080589 tok/s
Run 1: 28.8159080589 tok/s
Run 2: 28.8159080589 tok/s
sanchit-gandhi commented 1 month ago

Note @fxmarty that this issue also occurs when we pass only the input_ids to the model (and not the attention mask).

fxmarty commented 1 month ago

@sanchit-gandhi Interesting, is this greedy search? With Llama under greedy search the input_ids stride is always the same, but it might be safer to call contiguous/clone anyway.
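For reference, what contiguous/clone buys here: however a view was produced, .contiguous() returns a tensor with standard row-major strides, so a compiled function guarded on strides always sees the same layout for a given shape. A tiny standalone sketch with toy tensors (not the actual generate internals):

import torch

base = torch.arange(80).reshape(8, 10)

a = base[:4, :8]      # shape (4, 8), strides (10, 1)
b = base.t()[:4, :8]  # shape (4, 8), strides (1, 10): same shape, different layout

print(a.stride(), b.stride())                            # (10, 1) (1, 10)
print(a.contiguous().stride(), b.contiguous().stride())  # (8, 1) (8, 1)

Presumably the PR applies this idea inside the transformers modeling/generation code rather than requiring users to call contiguous at the call site.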

sanchit-gandhi commented 1 month ago

It's sampling (we set do_sample=True, temperature=1). Having played around with your PR, it looks like Gemma-2 is affected by the same issue as LLaMA, so I've pushed the changes for Gemma/Gemma-2 directly to your PR!