huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers

model.generate slower than model forward call #32870

Closed geekifan closed 1 month ago

geekifan commented 1 month ago

System Info

transformers 4.44.0, Python 3.11, CUDA 12.4

Who can help?

@zucchini-nlp

Reproduction

I need to extract the hidden states at the position where the model predicts the next token. Comparing model.generate with a plain model forward call, I find the forward call is much faster than model.generate.

Example model: LLaVA-NeXT. Prompt: \<image>\n Summarize it in one word:

Model forward call:

inputs = processor(input_prompts, images=batch['image'], return_tensors="pt", padding=True).to(device)
# Last layer's hidden state at the final position, i.e. the next-token position.
emb = model(**inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]

model.generate:

inputs = processor(input_prompts, images=batch['image'], return_tensors="pt", padding=True).to(device)
# hidden_states[0] is the first generation step (the prefill over the prompt); [-1] selects the last layer.
emb = model.generate(**inputs, max_new_tokens=60, output_hidden_states=True, return_dict_in_generate=True).hidden_states[0][-1][:, -1, :]

The forward call processes 1000 samples in about 2 minutes, while model.generate needs about 20 minutes, i.e. roughly 10x slower.
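
For reference, a minimal timing sketch (it assumes the model and inputs from the snippets above and a CUDA device; torch.cuda.synchronize makes the timer wait for queued GPU work to finish):

import time
import torch

def time_once(fn):
    # Synchronize before and after so we measure completed GPU work, not just kernel launches.
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn()
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

_, t_fwd = time_once(lambda: model(**inputs, output_hidden_states=True, return_dict=True))
_, t_gen = time_once(lambda: model.generate(**inputs, max_new_tokens=60, output_hidden_states=True, return_dict_in_generate=True))
print(f"forward: {t_fwd:.3f}s  generate: {t_gen:.3f}s")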

Expected behavior

model.generate should be about as fast as a plain model forward call.

yonikremer commented 1 month ago

model.generate produces up to 60 new tokens when max_new_tokens=60, one forward pass per token, while model(...) runs a single forward pass and only computes the next token's distribution. Setting max_new_tokens=1 will make the two roughly equal.
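
Conceptually, generate runs one forward pass per new token, so max_new_tokens=60 can cost up to 60 forward passes. A simplified greedy-decoding sketch of the idea (not the actual transformers implementation, which also uses a KV cache, stopping criteria, etc.; text-only for clarity, so no pixel_values):

import torch

@torch.no_grad()
def naive_greedy_generate(model, input_ids, max_new_tokens):
    # One full forward pass per new token: generation cost scales with max_new_tokens.
    for _ in range(max_new_tokens):
        logits = model(input_ids=input_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids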

zucchini-nlp commented 1 month ago

Right, generation with only one new token should take about the same time. Let me know if it doesn't.

geekifan commented 1 month ago

@zucchini-nlp @yonikremer Thanks for your replies! Setting max_new_tokens=1 solves the problem; the speed is roughly equal with both methods. Sorry for misunderstanding max_new_tokens: I thought the model would generate only the one token I needed, no matter what max_new_tokens was set to.
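
For anyone who lands here later, the variant that works (a sketch using the same processor/model/inputs as above; the prefill hidden states returned by generate should match a plain forward pass on the same inputs, up to minor numeric noise):

import torch

# Plain forward pass: last layer's hidden state at the final (next-token) position.
emb_fwd = model(**inputs, output_hidden_states=True, return_dict=True).hidden_states[-1][:, -1, :]

# generate with a single new token: hidden_states[0] is the prefill step, [-1] the last layer.
out = model.generate(**inputs, max_new_tokens=1, output_hidden_states=True, return_dict_in_generate=True)
emb_gen = out.hidden_states[0][-1][:, -1, :]

torch.testing.assert_close(emb_fwd, emb_gen)  # should pass within default tolerances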