huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Performance degradation with BF16 precision #27994

Closed jerin-scalers-ai closed 5 months ago

jerin-scalers-ai commented 6 months ago

System Info

Transformers: 4.35.2
Torch: 2.1.1-cpu
CPU: Intel Xeon 4th Gen processor

Who can help?

@ArthurZucker Hi, I was comparing the performance of the Llama 2 7b chat hf model across different precisions. I observed a significant degradation in performance (inference time) with bfloat16 compared to the fp32 model on an Intel CPU. BF16 is supposed to give better performance than FP32. Please refer to the table below for details:

| Precision | Tokens Generated | Infer time (sec) |
|-----------|------------------|------------------|
| FP32      | 186              | 12.51            |
| BF16      | 186              | 115.37           |
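
As a side note, whether bfloat16 is fast on CPU depends on the hardware actually exposing BF16 instructions (for a 4th Gen Xeon that would be AVX-512 BF16 and AMX) and on PyTorch dispatching to them. Below is a minimal sketch to check the advertised CPU flags on Linux; the exact flag names are assumptions about what this CPU should expose.

# Minimal sketch (Linux only): check whether the CPU advertises BF16 instructions.
# The flag names below (avx512_bf16, amx_bf16, amx_tile) are assumptions about
# what a 4th Gen Xeon should report in /proc/cpuinfo.
flags = set()
with open("/proc/cpuinfo") as f:
    for line in f:
        if line.startswith("flags"):
            flags.update(line.split(":", 1)[1].split())
            break

for flag in ("avx512_bf16", "amx_bf16", "amx_tile"):
    print(f"{flag}: {'yes' if flag in flags else 'no'}")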

Reproduction

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import time

model_id = "meta-llama/Llama-2-7b-chat-hf"
device = "cpu"
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_text = "In maximum 180 words, explain why purchasing Dell Poweredge servers offer much better TCO to enterprises compared to using public cloud infrastructure, for AI initiatives"

# Load the model through the text-generation pipeline with the weights in bfloat16 on CPU
text_generator = pipeline(
    "text-generation",
    model=model_id,
    tokenizer=tokenizer,
    return_tensors=True,
    device=device,
    torch_dtype=torch_dtype,
)

for _ in range(5):
    s_time = time.time()
    # Inference benchmarking
    output = text_generator(
        input_text,
        max_new_tokens=256,
        temperature=1,
    )
    e_time = time.time()
    # print(output)
    print(tokenizer.decode(output[0]["generated_token_ids"]))
    num_tokens = len(output[0]["generated_token_ids"])
    print(f"Num tokens: {num_tokens}")
    print(f"Infer time: {e_time-s_time}")

Expected behavior

BF16 is supposed to give better performance than FP32.

ArthurZucker commented 6 months ago

The inference time in bfloat16 depends on the hardware you are using, as well as the PyTorch version. I would recommend using float16 for inference.
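
For reference, switching the reproduction script to float16 only requires changing the dtype passed to the pipeline. A minimal sketch follows; whether it is actually faster still depends on the CPU and the PyTorch build, so it is worth benchmarking rather than assuming.

import torch
from transformers import pipeline

# Same pipeline as in the reproduction script, but loading the weights in float16.
# float16 performance on CPU also depends on hardware and PyTorch kernel support.
text_generator_fp16 = pipeline(
    "text-generation",
    model="meta-llama/Llama-2-7b-chat-hf",
    device="cpu",
    torch_dtype=torch.float16,
)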

github-actions[bot] commented 5 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.