cc @gante
Super weird, and I can indeed reproduce. The fix is `use_cache=False`. It's counterintuitive, but this will work:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_bos_token = False

# compile the full forward pass
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# warm-up / compile on the first prompt
inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))

# a different prompt: with use_cache=False the output now matches it
inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))
```
When trying the original script with `torch==2.4.0.dev20240418+cu121`, I get an `Aborted (core dumped)` preceded by `RuntimeError: Triton Error [CUDA]: device-side assert triggered` and a bunch of out-of-bounds memory accesses 👀
@ArthurZucker's suggested script hits the same exceptions (because the first calls run into the same issue).
Ah! I did not get that and generated successfully, no idea what went wrong on your end.
That's what I got, and it was pretty fast.
Thanks @ArthurZucker!
Is there a way to replicate this `use_cache=False` behavior when manually writing the generate function, like this one (based on your code): https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py
The reason is that it's better to compile the `decode_one_token` function instead of the whole forward pass, to avoid recompiling every time the input prompt shape changes.
I guess the place to pass `use_cache=False` would be here? https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L72
@gante I get that sometimes as well. I think it's a bit better with the torch nightly build.
@ArthurZucker `use_cache=False` is not really a solution; the speed is much slower vs. `use_cache=True`.
I was not able to make it work properly by setting `use_cache=False` directly in the model's forward pass either.
@gante that CUDA issue mainly happens when you compile the whole forward pass; normally you only need to compile the forward pass for the decoding step (the input is a single token with a fixed shape), not the prefill.
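To make that split concrete, here is a minimal greedy-decoding sketch (eager prefill, compiled fixed-shape decode step). It is only an outline under assumptions, not the hqq code: it assumes a transformers version where `StaticCache` can be passed directly as `past_key_values` together with `cache_position`, and that the cache exposes `reset()` (older versions may not; see the manual zeroing sketch below).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# One pre-allocated static cache, reused across prompts.
cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024,
                    device=model.device, dtype=model.dtype)

def decode_one_token(token, cache_position):
    # Fixed-shape step: the input is always [1, 1], so torch.compile only ever sees one shape.
    logits = model(token, past_key_values=cache, cache_position=cache_position, use_cache=True).logits
    # Clone so the result is not aliased to a CUDA-graph buffer under reduce-overhead.
    return logits[:, -1].argmax(dim=-1, keepdim=True).clone()

# Compile only the decode step; the prefill stays eager, so prompt-length changes never recompile.
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

@torch.no_grad()
def generate(prompt, max_new_tokens=100):
    cache.reset()  # clear stale entries from the previous prompt

    input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
    seq_len = input_ids.shape[1]

    # Eager prefill: fills the static cache with the prompt's keys/values.
    cache_position = torch.arange(seq_len, device=model.device)
    logits = model(input_ids, past_key_values=cache, cache_position=cache_position, use_cache=True).logits
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

    generated = [next_token]
    for i in range(max_new_tokens - 1):
        cache_position = torch.tensor([seq_len + i], device=model.device)
        next_token = decode_one_token(next_token, cache_position)
        generated.append(next_token)

    return tokenizer.decode(torch.cat([input_ids] + generated, dim=-1)[0])
```

The point is that `torch.compile` only ever traces the `[1, 1]` decode shape, so a new prompt length does not trigger recompilation; only the eager prefill sees variable shapes.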
@mobicham, normally you should not have this issue with the script that compiles `decode_one_token`. I pushed a fix to main that should have solved this: #30380; the cache was probably not being overwritten.
I think `reset_cache` might not work as expected.
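For reference, a manual reset is roughly the following, assuming the `StaticCache` layout used by recent transformers versions (one pre-allocated key/value tensor per layer in the `key_cache` / `value_cache` lists):

```python
def reset_static_cache(cache):
    # Zero every layer's pre-allocated key/value buffers so entries written for the
    # previous prompt cannot leak into the next generation.
    for k, v in zip(cache.key_cache, cache.value_cache):
        k.zero_()
        v.zero_()
```

Zeroing the buffers in place like this also keeps the tensor addresses stable, which matters when the compiled graph runs under `reduce-overhead` / CUDA graphs.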
@ArthurZucker thanks! I found a hack: warm up with `use_cache=False` the very first time you compile, then use `use_cache=True` for generation. It still needs to warm up again with `use_cache=True`, but at least the output is correct.
Update: the warm-up with the full torch.compile takes a lot of VRAM. The best would be to make it work with `decode_one_token`. I still haven't found a proper way of doing it.
There's another problem: if you compile using, for example, `max_new_tokens=100` and then use `max_new_tokens=1000` after the warm-up, you get `RuntimeError: CUDA error: device-side assert triggered`. The trick is to use a larger `max_new_tokens` at compilation time; then it works with any value below that.
```python
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

prompt = "Write an essay about large language models."

# warm-up
for _ in range(10):
    gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=1000, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)

# timed generation with a new prompt
prompt = "How do I make a cake"

import time
t1 = time.time()
gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=True)
t2 = time.time()
print(len(gen_out[0]) / (t2 - t1), "tokens/sec")
```
Was not able to test the fix because there's another problem with 4.41.0: https://github.com/huggingface/transformers/issues/30417
Super weird and we'll fix it asap
Might be related to #30414 as well
I was finally able to make it work without blowing up the VRAM, using the `use_cache=False` warm-up.
With this approach, a 4-bit Llama2-7B takes ~5.6GB of VRAM at runtime with a max cache size of 1024.
If I try the same with `model.generate()`, I run out of VRAM after the 2nd or 3rd warm-up prompt.
The only issue is the speed: with the approach above I get 165 tokens/sec, while it should be ~205 tokens/sec.
Update: the speed depends on the size of the initialized cache for some reason.
Update 2: it is actually not fixed, the outputs still mix in content from previous prompts. I will try the fix as soon as https://github.com/huggingface/transformers/issues/30417 is resolved.
Wow thanks a lot for all this valuable debugging, would really love to fix this!
Thanks @ArthurZucker! I spent the whole day playing with this; the latest version is here. Here's what I noticed so far:
`RuntimeError: CUDA error: device-side assert triggered`
BTW, we are gonna move forward with #30476
Thank you for the update!
System Info
`transformers` version: 4.39.0.dev0
Who can help?
@ArthurZucker
Reproduction
When using `torch.compile(model.forward)` with a static cache, the cache seems to be locked to the first prompt that was used at compilation time. I re-implemented the generate logic and the same issue happens, so it's not just a bug with `model.generate`. This happens with both older and newer versions of transformers. Here's a code snippet:
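A minimal sketch of the kind of script that exhibits this, following the compile-plus-static-cache pattern from the thread above (model, prompts, and generation settings are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile the whole forward pass (this is where the cache gets "locked" to the first prompt).
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# First prompt: compiles/warms up and generates correctly.
inputs = tokenizer(["How do I make a chocolate cake?"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100)
print(tokenizer.decode(gen_out[0]))

# Second prompt: the output still answers the first prompt instead of this one.
inputs = tokenizer(["Write an essay about large language models."], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100)
print(tokenizer.decode(gen_out[0]))
```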
Expected behavior
The output should correspond to the input prompt, not the prompt the model was first compiled with.
Thank you!