huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Static cache is locked after torch.compile with model.generate #30351

Closed mobicham closed 6 months ago

mobicham commented 7 months ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

When using torch.compile(model.forward) with a static cache, the cache seems to be locked to the first prompt that was used at compilation time. I re-implemented the generate logic and the same issue happens, so it's not just a bug in model.generate. This happens with both older and newer versions of transformers.

Here's a code snippet:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval();
tokenizer = AutoTokenizer.from_pretrained(model_id) 
tokenizer.add_bos_token = False

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None)
print(tokenizer.decode(gen_out[0]))

# Output: OK
#  <s>  [INST] Write an essay about large language models [/INST]   Large language models have revolutionized the field of natural language processing in recent years. 
# These models are trained on vast amounts of text data and are capable of generating text, classifying text, and answering questions with remarkable accuracy. 
# In this essay, we will explore the current state of large language models, their potential applications, and the challenges and limitations that come with their use.....

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None)
print(tokenizer.decode(gen_out[0]))

# Output: WRONG, still talks about the previous prompt.
# <s>  [INST] How to make a chocolate cake? [/INST]  ge language models (LLMs) are a class of artificial intelligence (AI) models that have gained significant 
#attention in recent years due to their impressive language processing capabilities. Here, we will explore the concept of LLMs, their applications, 
# and their potential impact on various fields.
# What are Large Language Models?
# LLMs are neural network-based models that are trained on vast amounts of text data to generate language outputs that are coherent and natural

Expected behavior

The output should correspond to the input prompt, not the prompt the model was first compiled with.

Thank you!

amyeroberts commented 7 months ago

cc @gante

ArthurZucker commented 7 months ago

Super weird, and I can indeed reproduce. The fix is use_cache=False. It's counterintuitive, but this will work:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval();
tokenizer = AutoTokenizer.from_pretrained(model_id) 
tokenizer.add_bos_token = False

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))

gante commented 7 months ago

When trying the original script with torch==2.4.0.dev20240418+cu121 I get an Aborted (core dumped), preceded by RuntimeError: Triton Error [CUDA]: device-side assert triggered and a bunch of out-of-bounds memory access errors 👀

@ArthurZucker's suggested script hits the same exceptions (because the first calls run into the same issue).

ArthurZucker commented 7 months ago

Ah! I did not get that, generation succeeded on my end; no idea what went wrong with yours.

ArthurZucker commented 7 months ago
[image]

That's what I got, and it was pretty fast.

mobicham commented 7 months ago

Thanks @ArthurZucker! Is there a way to replicate this use_cache=False behavior when manually writing the generate function, like this one (based on your code): https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py

The reason is that it's better to compile the decode_one_token function instead of the whole forward pass, to avoid recompilation every time the input prompt shape changes.

I guess I should pass use_cache=False here? https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L72
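
For context, this is roughly the structure I mean, as a simplified sketch rather than the actual hqq code (the StaticCache / cache_position usage follows recent transformers versions and may differ slightly):

import torch
from transformers import StaticCache

# Compile only the single-token decoding step: its input shape is fixed at
# [batch_size, 1], so the graph is traced once and reused for every new token.
@torch.compile(mode="reduce-overhead", fullgraph=True)
def decode_one_token(model, cur_token, cache_position, past_key_values):
    out = model(cur_token, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
    return out.logits[:, -1].argmax(dim=-1, keepdim=True)

@torch.no_grad()
def generate_simple(model, tokenizer, prompt, max_new_tokens=100, cache_len=1024):
    input_ids = tokenizer([prompt], return_tensors="pt").to(model.device)["input_ids"]
    batch_size, seq_len = input_ids.shape

    # Static cache pre-allocated to a fixed maximum length
    past_key_values = StaticCache(config=model.config, max_batch_size=batch_size, max_cache_len=cache_len, device=model.device, dtype=model.dtype)

    # Prefill: run uncompiled so that varying prompt lengths do not trigger recompilation
    cache_position = torch.arange(seq_len, device=model.device)
    out = model(input_ids, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    generated = [next_token]
    cache_position = torch.tensor([seq_len], device=model.device)
    for _ in range(max_new_tokens - 1):
        next_token = decode_one_token(model, next_token, cache_position, past_key_values)
        generated.append(next_token)
        cache_position += 1

    return torch.cat([input_ids] + generated, dim=-1)

In the linked code the cache is, as far as I can tell, allocated once and reused across prompts rather than re-created per call, which is presumably where the stale outputs come from.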

@gante I get that sometimes as well. I think it's a bit better with the torch nightly build.

mobicham commented 7 months ago

@ArthurZucker use_cache=False is not really a solution: the speed is much slower than with use_cache=True. I was not able to make it work properly by setting use_cache=False directly in the model's forward pass either.

@gante that CUDA issue mainly happens when you compile the whole forward pass. Normally you only need to compile the forward pass for the decoding part (the input is a single token with a fixed shape), not the prefill.

ArthurZucker commented 7 months ago

@mobicham, normally you should not have this issue with the script that compiles decode_one_token. I pushed a fix to main that should solve this: #30380; the cache was probably not being overwritten. I think reset_cache might not work as expected.

mobicham commented 7 months ago

@ArthurZucker thanks! I found a hack: warm up with use_cache=False the very first time you compile, then use use_cache=True for generation. It still needs to warm up again with use_cache=True, but at least the output is correct.

Update: the warm-up with the full torch.compile takes a lot of VRAM. The best option would be to make it work with decode_one_token; I still haven't found a proper way of doing it.

There's another problem: if you compile using max_new_tokens=100, for example, and then use max_new_tokens=1000 after the warm-up, you get RuntimeError: CUDA error: device-side assert triggered. The trick is to use a larger max_new_tokens at compilation time; then it works with any value smaller than that.

# tokenize_prompt(prompt) is a small helper that tokenizes the prompt and moves it to the model's device
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

prompt = "Write an essay about large language models."

# warm-up
for _ in range(10):
    gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=1000, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)

prompt = "How do I make a cake"
import time
t1 = time.time()
gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=True)
t2 = time.time()
print(len(gen_out[0])/(t2-t1), "tokens/sec")

mobicham commented 7 months ago

I was not able to test the fix because there's another problem with 4.41.0: https://github.com/huggingface/transformers/issues/30417

ArthurZucker commented 7 months ago

Super weird; we'll fix it ASAP.

ArthurZucker commented 7 months ago

Might be related to #30414 as well

mobicham commented 7 months ago

I was finally able to make it work without blowing up the VRAM:

  1. Compile with inputs of size [batch_size, 1]: https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L57-L72
  2. Warm up with 3 prompts with use_cache=False
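
Concretely, the warm-up in step 2 is just a loop over a few throwaway prompts through the compiled path. Here is a rough sketch reusing the generate_simple function from my earlier comment as a stand-in for the hqq generate loop (the use_cache=False switch would still need to be threaded through the real implementation, and the prompts are arbitrary placeholders):

# Warm-up: run a few throwaway prompts through the compiled decode path so that
# later generations hit already-compiled graphs.
warmup_prompts = [
    "Write an essay about large language models.",
    "How do I make a chocolate cake?",
    "Summarize the history of the internet.",
]
for prompt in warmup_prompts:
    _ = generate_simple(model, tokenizer, prompt, max_new_tokens=16)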

With this approach, a 4-bit Llama2-7B takes ~5.6GB of VRAM at runtime with a maximum cache size of 1024. If I try the same with model.generate(), I run out of VRAM after the 2nd or 3rd warm-up prompt.

The only issue is the speed: with the approach above I get 165 tokens/sec, while it should be ~205 tokens/sec.

Update: the speed depends on the size of the initialized cache for some reason.

Update 2: it is actually not fixed; the outputs still mix in content from previous prompts. I will try the fix as soon as https://github.com/huggingface/transformers/issues/30417 is resolved.

ArthurZucker commented 7 months ago

Wow, thanks a lot for all this valuable debugging; I would really love to fix this!

mobicham commented 7 months ago

Thanks @ArthurZucker! I spent the whole day playing with this; the latest version is here. Here's what I noticed so far:

ArthurZucker commented 7 months ago

BTW, we are going to move forward with #30476

mobicham commented 7 months ago

Thank you for the update!