huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

batch inference scales linearly with batch size when input is long #33411

Closed platypus1989 closed 3 weeks ago

platypus1989 commented 1 month ago

System Info

transformers.version=4.42.4

Who can help?

@Gante

Reproduction

Hi, I am noticing that when running batch inference with Mixtral-8x7B-Instruct-v0.1, inference time scales nicely (sublinearly) with batch size when the input is small, but once the input gets large (more than 400 tokens), inference time starts to scale linearly with batch size.

Some sample code to reproduce what I am seeing:

import torch
from time import time
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'mistralai/Mixtral-8x7B-Instruct-v0.1'

model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
tokenizer = AutoTokenizer.from_pretrained(model_id)

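def inference_time(input_size, batch_size):
    # input_size is the number of times the phrase is repeated, not a token count.
    prompt = "how are you? "*input_size
    # Every prompt in the batch is identical, so no padding is needed.
    prompts = [prompt]*batch_size
    encoded = tokenizer(prompts, return_tensors='pt')
    input_ids = encoded['input_ids'].to(model.device)
    attention_mask = encoded['attention_mask'].to(model.device)
    # Time a single forward pass (no generation) without gradient tracking.
    tic = time()
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)
    return time() - tic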

input_sizes = []
batch_sizes = []
wall_time = []
for i in [1, 5, 10, 20, 50, 100]:
    for j in [1, 5, 10, 20, 50, 100]:
        input_sizes.append(i)
        batch_sizes.append(j)
        wall_time.append(inference_time(i, j))

pd.DataFrame({
    'input_size': input_sizes,
    'batch_size': batch_sizes,
    'inference_time': wall_time,
})

As you can see from the output, the inference time scales sublinearly with batch size when the input size is less than 10. Once the input size exceeds 20, the inference time starts to scale roughly linearly with batch size.
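One caveat about measurements like these: CUDA kernels are launched asynchronously, so wall-clock timings are usually more trustworthy when the device is synchronized before the clock is read and a warm-up run is discarded. Below is a minimal sketch of such a timing helper. The name `timed` and its parameters are hypothetical and not part of the original script, and it is not how the numbers above were obtained.

```python
from time import perf_counter


def timed(fn, sync=None, warmup=1, repeats=3):
    """Return the best (minimum) wall time of fn() over `repeats` runs.

    fn:   zero-argument callable to measure (hypothetical helper interface).
    sync: optional zero-argument callable that blocks until queued device
          work has finished, e.g. torch.cuda.synchronize.
    """
    # Warm-up runs are executed but not timed.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    best = float("inf")
    for _ in range(repeats):
        start = perf_counter()
        fn()
        # Wait for outstanding device work before reading the clock.
        if sync is not None:
            sync()
        best = min(best, perf_counter() - start)
    return best
```

In the script above this could be called as `timed(lambda: model(input_ids=input_ids, attention_mask=attention_mask), sync=torch.cuda.synchronize)` inside the `torch.no_grad()` block. With `device_map='auto'` spreading the model over several GPUs, the `sync` callable would probably need to call `torch.cuda.synchronize(i)` for each device index.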

[Image: Screenshot 2024-09-10 at 12 29 11 AM]

Expected behavior

I expected batch inference time to scale sublinearly with batch size, at least to some extent, regardless of the input size.
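One way to make "sublinear" versus "linear" concrete is to fit the slope of log(time) against log(batch size) for each input size. A slope near 1 indicates linear scaling, and a slope well below 1 indicates sublinear scaling. The sketch below assumes the timings for one input size have been collected into plain lists. The function name `scaling_exponent` is hypothetical, and the numbers are synthetic placeholders, not measurements from this issue.

```python
import math


def scaling_exponent(batch_sizes, times):
    """Least-squares slope of log(time) versus log(batch_size)."""
    xs = [math.log(b) for b in batch_sizes]
    ys = [math.log(t) for t in times]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var


# Synthetic timings for illustration only (NOT results from the issue).
batch_sizes = [1, 5, 10, 20, 50, 100]
linear_times = [0.1 * b for b in batch_sizes]   # time proportional to batch size
flat_times = [0.1 for _ in batch_sizes]         # time independent of batch size

linear_slope = scaling_exponent(batch_sizes, linear_times)
flat_slope = scaling_exponent(batch_sizes, flat_times)
```

Applied to the DataFrame from the reproduction script, the same function could be run once per `input_size` group to see where the exponent approaches 1.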

platypus1989 commented 1 month ago

Btw, I am running the experiment on an instance with 4 A100 GPUs.

github-actions[bot] commented 4 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.