huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
134.93k stars, 26.99k forks

CPU processing is extremely slow for models loaded with `torch_dtype = torch.float16` #34692

Open blincoln-bf opened 2 days ago

blincoln-bf commented 2 days ago

System Info

Transformers versions: 4.44.2, 4.46.2
PyTorch versions: 2.4.0, 2.5.1
Python version: 3.11.2

Who can help?

No response


Reproduction

While troubleshooting some very strange behaviour with GPT-NeoX when processed on a CPU device, I discovered that Transformers will load a model with torch_dtype = torch.float16 and process it on the CPU without any warning or other message. However, it runs far more slowly than with float32 (5+ times slower) or bfloat16 (10+ times slower). A lot of documentation online suggests using .half() or torch_dtype = torch.float16 to conserve memory. Given that, could a warning message be added when a model loaded in this format is processed on a CPU device?

I know that float16 support on the CPU is relatively new. But I couldn't find any information about the massive performance hit it currently incurs, so I had assumed that CPU processing in Transformers as a whole was essentially unusable for real-world work (especially training and other gradient operations). In reality, CPU processing is surprisingly fast when the model is loaded as float32 or bfloat16.
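For context on the memory argument: float16 and bfloat16 both use 2 bytes per parameter, versus 4 bytes for float32, so bfloat16 saves exactly as much memory as float16. The rough estimate below covers weights only. It assumes a round 20 billion parameters for GPT-NeoX-20B (an assumption for illustration) and ignores activations, the KV cache, and framework overhead.

```python
# Rough weight-memory estimate per dtype (weights only).
# Assumption: a round 20e9 parameters, used purely for illustration.
BYTES_PER_PARAM = {
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
}


def weight_memory_gib(num_params: int, dtype_name: str) -> float:
    """Return the memory needed to hold num_params weights of the given dtype, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype_name] / 2**30


if __name__ == "__main__":
    n = 20_000_000_000  # assumed round parameter count
    for name in ("float32", "float16", "bfloat16"):
        print(f"{name:>8}: {weight_memory_gib(n, name):6.1f} GiB")
```

Either 2-byte format roughly halves the weight footprint. Choosing bfloat16 over float16 on a CPU therefore gives up no memory savings.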

Here's a quick benchmark script based on the example usage for GPT-NeoX that loads the model, then generates text using three prompts. It performs this test for torch_dtype = None, torch.float32, torch.float16 and torch.bfloat16:

#!/usr/bin/env python

# https://huggingface.co/docs/transformers/en/model_doc/gpt_neox#usage-example
# But modified to debug an issue

import datetime
import gc
import json
import torch

from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from transformers.generation import GenerationConfig

def output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, step_description):
    generation_current_dt = datetime.datetime.now()
    generation_current_dt_string = generation_current_dt.replace(microsecond = 0).isoformat()
    generation_elapsed = generation_current_dt - generation_start_dt
    generation_elapsed_string = f"{generation_elapsed.total_seconds()}"
    step_elapsed = generation_current_dt - generation_previous_dt
    step_elapsed_string = f"{step_elapsed.total_seconds()}"
    print(f"\tReached prompt {prompt_num} stage: '{step_description}' at {generation_current_dt_string}. {step_elapsed_string} seconds elapsed between the previous stage and this stage. {generation_elapsed_string} seconds elapsed for prompt {prompt_num} in total.")
    return generation_current_dt

def generate_text(model, tokenizer, prompt, prompt_num):
    generation_start_dt = datetime.datetime.now()
    generation_start_dt_string = generation_start_dt.replace(microsecond = 0).isoformat()
    generation_previous_dt = generation_start_dt
    print(f"Beginning processing prompt {prompt_num} at {generation_start_dt_string}")

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished tokenizing input. Creating attention mask.")

    attn_masks = torch.ones_like(input_ids)

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished creating attention mask. Generating output.")

    existing_config_dict = model.generation_config.to_dict()

    gen_config = GenerationConfig.from_dict(config_dict = existing_config_dict)

    gen_config.do_sample = False
    gen_config.temperature = 1.0
    gen_config.max_length = 80

    gen_tokens = model.generate(
        input_ids,
        attention_mask = attn_masks,   
        generation_config = gen_config
    )

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished generating output. Decoding token IDs to text.")

    gen_text = tokenizer.batch_decode(gen_tokens)[0]

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished decoding token IDs to text.")

    print(f"Model input: '{prompt}'\nModel output: '{gen_text}'")

    generation_end_dt = datetime.datetime.now()
    generation_end_dt_string = generation_end_dt.replace(microsecond = 0).isoformat()
    generation_elapsed = generation_end_dt - generation_start_dt
    generation_elapsed_string = f"{generation_elapsed.total_seconds()}"
    print(f"Finished processing prompt {prompt_num} at {generation_end_dt_string} - {generation_elapsed_string} seconds elapsed in total for this prompt.")

def main():
    model_path = "EleutherAI/gpt-neox-20b"
    tokenizer_path = model_path

    prompts = []
    # increase this range to test for behaviour that changes with successive operations
    for i in range(0, 1):
        prompts.append("<|prompter|>Please tell me about GPTNeoX20B, the 20B-parameter autoregressive Transformer model developed by EleutherAI.<|endoftext|><|assistant|>")
        prompts.append("<|prompter|>Please tell me about GPTNeoX20B, the 20B-parameter autoregressive Transformer model developed by EleutherAI.<|endoftext|><|assistant|>Sure, I'd be happy to tell you about GPTNeoX20B")
        prompts.append("GPTNeoX20B is a 20B-parameter autoregressive Transformer model developed by EleutherAI.")

    dtype_list = [ None, torch.float32, torch.float16, torch.bfloat16 ]

    for current_dtype in dtype_list:
        start_dt = datetime.datetime.now()
        start_dt_string = start_dt.replace(microsecond = 0).isoformat()

        print(f"Starting test with torch_dtype = {current_dtype} at {start_dt_string}")

        model = GPTNeoXForCausalLM.from_pretrained(
            model_path,
            torch_dtype = current_dtype
            ).to(device = "cpu").eval()

        print(f"model.dtype (when loaded with torch_dtype {current_dtype}) = {model.dtype}")

        tokenizer = GPTNeoXTokenizerFast.from_pretrained(tokenizer_path)

        for prompt_num, prompt in enumerate(prompts):
            generate_text(model, tokenizer, prompt, prompt_num)

        end_dt = datetime.datetime.now()
        end_dt_string = end_dt.replace(microsecond = 0).isoformat()
        elapsed = end_dt - start_dt
        elapsed_string = f"{elapsed.total_seconds()}"

        print(f"Finished test with torch_dtype = {current_dtype} at {end_dt_string} - {elapsed_string} seconds elapsed for the entire test cycle using this dtype.")

        del model
        gc.collect()

if __name__ == "__main__":
    main()
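The per-stage bookkeeping in `output_generation_step_info` can also be written as a small context manager. `StageTimer` is a hypothetical helper (it is not part of Transformers). It uses `time.perf_counter`, which is monotonic and better suited to measuring intervals than wall-clock `datetime.now()`.

```python
import time
from contextlib import contextmanager


class StageTimer:
    """Hypothetical helper: records elapsed seconds for named stages of one prompt."""

    def __init__(self, label: str):
        self.label = label
        self.stages: dict[str, float] = {}
        self._start = time.perf_counter()

    @contextmanager
    def stage(self, name: str):
        # Time the body of the with-block and store the duration under `name`.
        t0 = time.perf_counter()
        try:
            yield
        finally:
            self.stages[name] = time.perf_counter() - t0

    def total(self) -> float:
        # Seconds since this timer was created.
        return time.perf_counter() - self._start

    def report(self) -> str:
        lines = [f"{self.label}: {name} took {secs:.3f} s" for name, secs in self.stages.items()]
        lines.append(f"{self.label}: {self.total():.3f} s in total")
        return "\n".join(lines)


if __name__ == "__main__":
    timer = StageTimer("demo")
    with timer.stage("sleep"):
        time.sleep(0.01)  # stand-in for tokenizing or model.generate(...)
    print(timer.report())
```

In `generate_text`, the `tokenizer(...)` and `model.generate(...)` calls would each go inside their own `with timer.stage(...)` block, and `timer.report()` would replace the individual print statements.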

Excerpt of the output with just the relevant statistics:

Finished processing prompt 0 at 2024-11-11T13:09:31 - 88.848283 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T13:10:20 - 49.434201 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T13:12:15 - 114.557077 seconds elapsed in total for this prompt.
Finished test with torch_dtype = None at 2024-11-11T13:12:15 - 260.500114 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-11T13:13:42 - 79.813121 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T13:14:31 - 48.493672 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T13:16:21 - 109.835365 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float32 at 2024-11-11T13:16:21 - 245.690124 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-11T13:25:36 - 553.680202 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T13:34:44 - 548.825131 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T13:44:00 - 555.500857 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float16 at 2024-11-11T13:44:00 - 1659.152628 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-11T13:44:44 - 37.687044 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T13:45:05 - 21.75417 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T13:45:58 - 52.408085 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.bfloat16 at 2024-11-11T13:45:58 - 117.512555 seconds elapsed for the entire test cycle using this dtype.

As you can see, float16 was about 7 times slower than float32 for this run, and about 14 times slower than bfloat16. Simple text generation took almost ten minutes per prompt in float16 format. For training/gradient operations, the effect is even worse: operations that take a few minutes in the other formats can take hours in float16 format (in the case of the GPT-NeoX issue, 10+ hours for a call to forward). I don't have a good minimal test case for that, though.
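The ratios quoted above follow directly from the per-dtype totals in the log excerpt. The snippet below recomputes them using the numbers copied from that output.

```python
# Whole-cycle seconds per dtype for GPT-NeoX-20B, copied from the log excerpt above.
totals = {
    "None": 260.500114,
    "float32": 245.690124,
    "float16": 1659.152628,
    "bfloat16": 117.512555,
}


def slowdown(slow: str, fast: str, data: dict[str, float]) -> float:
    """How many times longer the `slow` dtype took than the `fast` dtype."""
    return data[slow] / data[fast]


print(f"float16 vs float32:  {slowdown('float16', 'float32', totals):.2f}x")
print(f"float16 vs bfloat16: {slowdown('float16', 'bfloat16', totals):.2f}x")
```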

This is not limited to GPT-NeoX. For example, here's the same script, but modified to use Phi-3-mini-128k instead:

#!/usr/bin/env python

# https://huggingface.co/docs/transformers/en/model_doc/gpt_neox#usage-example
# But modified to debug an issue

import datetime
import gc
import json
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

def output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, step_description):
    generation_current_dt = datetime.datetime.now()
    generation_current_dt_string = generation_current_dt.replace(microsecond = 0).isoformat()
    generation_elapsed = generation_current_dt - generation_start_dt
    generation_elapsed_string = f"{generation_elapsed.total_seconds()}"
    step_elapsed = generation_current_dt - generation_previous_dt
    step_elapsed_string = f"{step_elapsed.total_seconds()}"
    print(f"\tReached prompt {prompt_num} stage: '{step_description}' at {generation_current_dt_string}. {step_elapsed_string} seconds elapsed between the previous stage and this stage. {generation_elapsed_string} seconds elapsed for prompt {prompt_num} in total.")
    return generation_current_dt

def generate_text(model, tokenizer, prompt, prompt_num):
    generation_start_dt = datetime.datetime.now()
    generation_start_dt_string = generation_start_dt.replace(microsecond = 0).isoformat()
    generation_previous_dt = generation_start_dt
    print(f"Beginning processing prompt {prompt_num} at {generation_start_dt_string}")

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished tokenizing input. Creating attention mask.")

    attn_masks = torch.ones_like(input_ids)

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished creating attention mask. Generating output.")

    existing_config_dict = model.generation_config.to_dict()

    gen_config = GenerationConfig.from_dict(config_dict = existing_config_dict)

    gen_config.do_sample = False
    gen_config.temperature = 1.0
    gen_config.max_length = 80

    gen_tokens = model.generate(
        input_ids,
        attention_mask = attn_masks,   
        generation_config = gen_config
    )

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished generating output. Decoding token IDs to text.")

    gen_text = tokenizer.batch_decode(gen_tokens)[0]

    generation_previous_dt = output_generation_step_info(prompt_num, generation_start_dt, generation_previous_dt, "Finished decoding token IDs to text.")

    print(f"Model input: '{prompt}'\nModel output: '{gen_text}'")

    generation_end_dt = datetime.datetime.now()
    generation_end_dt_string = generation_end_dt.replace(microsecond = 0).isoformat()
    generation_elapsed = generation_end_dt - generation_start_dt
    generation_elapsed_string = f"{generation_elapsed.total_seconds()}"
    print(f"Finished processing prompt {prompt_num} at {generation_end_dt_string} - {generation_elapsed_string} seconds elapsed in total for this prompt.")

def main():
    model_path = "/mnt/md0/Machine_Learning/LLMs/Microsoft/Phi-3-mini-128k-instruct"
    tokenizer_path = model_path

    prompts = []
    # increase this range to test for behaviour that changes with successive operations
    for i in range(0, 1):
        prompts.append("<|prompter|>Please tell me about GPTNeoX20B, the 20B-parameter autoregressive Transformer model developed by EleutherAI.<|endoftext|><|assistant|>")
        prompts.append("<|prompter|>Please tell me about GPTNeoX20B, the 20B-parameter autoregressive Transformer model developed by EleutherAI.<|endoftext|><|assistant|>Sure, I'd be happy to tell you about GPTNeoX20B")
        prompts.append("GPTNeoX20B is a 20B-parameter autoregressive Transformer model developed by EleutherAI.")

    dtype_list = [ None, torch.float32, torch.float16, torch.bfloat16 ]

    for current_dtype in dtype_list:
        start_dt = datetime.datetime.now()
        start_dt_string = start_dt.replace(microsecond = 0).isoformat()

        print(f"Starting test with torch_dtype = {current_dtype} at {start_dt_string}")

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype = current_dtype
            ).to(device = "cpu").eval()

        print(f"model.dtype (when loaded with torch_dtype {current_dtype}) = {model.dtype}")

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        #if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        #    tokenizer.pad_token_id = tokenizer.eos_token_id
        #    print(f"tokenizer.pad_token_id has been set to {tokenizer.pad_token_id} to match tokenizer.eos_token_id")

        for prompt_num, prompt in enumerate(prompts):
            generate_text(model, tokenizer, prompt, prompt_num)

        end_dt = datetime.datetime.now()
        end_dt_string = end_dt.replace(microsecond = 0).isoformat()
        elapsed = end_dt - start_dt
        elapsed_string = f"{elapsed.total_seconds()}"

        print(f"Finished test with torch_dtype = {current_dtype} at {end_dt_string} - {elapsed_string} seconds elapsed for the entire test cycle using this dtype.")

        del model
        gc.collect()

if __name__ == "__main__":
    main()

Relevant output for Phi-3-mini-128k:

Finished processing prompt 0 at 2024-11-11T14:06:13 - 17.336155 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T14:06:22 - 9.018136 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T14:06:44 - 22.536667 seconds elapsed in total for this prompt.
Finished test with torch_dtype = None at 2024-11-11T14:06:44 - 67.624459 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-11T14:07:02 - 16.449278 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T14:07:11 - 8.645172 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T14:07:33 - 22.538802 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float32 at 2024-11-11T14:07:33 - 48.931164 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-11T14:09:00 - 85.79342 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T14:10:25 - 84.946433 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T14:11:52 - 86.384148 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float16 at 2024-11-11T14:11:52 - 257.981995 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-11T14:11:59 - 7.338564 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-11T14:12:03 - 3.707063 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-11T14:12:13 - 9.802746 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.bfloat16 at 2024-11-11T14:12:13 - 21.009795 seconds elapsed for the entire test cycle using this dtype.

In this case, float16 is about 5 times slower than float32 and more than 12 times slower than bfloat16 overall. There seems to be some kind of fixed overhead causing the issue. For both Phi-3-mini-128k and GPT-NeoX, the float16 processing time is virtually identical for every prompt, even though the per-prompt times vary by a factor of two or more in the other formats.
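Both observations can be recomputed from the Phi-3 log above. The snippet compares the whole-cycle totals, and for each dtype it computes the ratio between the slowest and fastest prompt. The figures are copied from the output. The whole-cycle totals also include model and tokenizer loading, because the script starts its timer before `from_pretrained`.

```python
# Per-prompt seconds for Phi-3-mini-128k, copied from the log excerpt above.
per_prompt = {
    "None": [17.336155, 9.018136, 22.536667],
    "float32": [16.449278, 8.645172, 22.538802],
    "float16": [85.79342, 84.946433, 86.384148],
    "bfloat16": [7.338564, 3.707063, 9.802746],
}

# Whole-cycle totals from the same log (these also include model/tokenizer loading).
totals = {
    "None": 67.624459,
    "float32": 48.931164,
    "float16": 257.981995,
    "bfloat16": 21.009795,
}


def spread(times: list[float]) -> float:
    """Ratio of slowest to fastest prompt; close to 1.0 means every prompt took about the same time."""
    return max(times) / min(times)


for dtype, times in per_prompt.items():
    print(f"{dtype:>8}: slowest/fastest prompt = {spread(times):.2f}")

print(f"float16 vs float32 total:  {totals['float16'] / totals['float32']:.2f}x")
print(f"float16 vs bfloat16 total: {totals['float16'] / totals['bfloat16']:.2f}x")
```

The float16 spread comes out just above 1.0, while every other dtype is above 2.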

I assume the discrepancy is at least somewhat known to the Transformers developers, but I only discovered it myself while trying to debug a different problem. Adding a runtime warning, and perhaps an explicit note in the documentation, seems like a good idea.

Expected behavior

If CPU processing is performed using a very inefficient format that is also commonly suggested as a way to reduce the memory footprint, I would expect Transformers to issue a warning.
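As a rough illustration of the requested behaviour, a user-side check could look like the sketch below. `warn_if_float16_on_cpu` is a hypothetical helper, not an existing Transformers API. It only relies on the `model.dtype` and `model.device` attributes used in the scripts above. It compares them by their string form (`str(torch.float16)` is `"torch.float16"`, and `torch.device("cpu").type` is `"cpu"`), so the logic itself runs without torch.

```python
import warnings
from types import SimpleNamespace


def warn_if_float16_on_cpu(model) -> bool:
    """Hypothetical check: warn when a model holds float16 weights on a CPU device.

    Returns True if a warning was issued.
    """
    dtype_name = str(getattr(model, "dtype", ""))
    device = getattr(model, "device", None)
    # torch.device exposes .type; fall back to the string form for plain strings.
    device_type = getattr(device, "type", str(device))
    if dtype_name == "torch.float16" and device_type == "cpu":
        warnings.warn(
            "Model is loaded in float16 on a CPU device; on many CPUs this is much "
            "slower than float32 or bfloat16.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False


if __name__ == "__main__":
    # Stand-in object for illustration; with a real model you would call
    # warn_if_float16_on_cpu(model) right after from_pretrained(...).to(device="cpu").
    fake_model = SimpleNamespace(dtype="torch.float16", device=SimpleNamespace(type="cpu"))
    warn_if_float16_on_cpu(fake_model)
```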

Kevin0624 commented 2 days ago

Hi @blincoln-bf

Could you tell us which CPU model you are using? Generally, CPUs like the 13th Gen Intel(R) Core(TM) i5-13600K do not have specialized instruction sets for the BF16 and FP16 data formats. As a result, the CPU has to convert the data to FP32 for computation and then back to BF16 or FP16, and this back-and-forth conversion consumes a significant amount of time.
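On Linux, one way to see which reduced-precision instructions an x86 CPU advertises is to read the `flags` line in `/proc/cpuinfo`. The sketch below parses that text and reports a few relevant flags. The flag names (`f16c`, `avx512_fp16`, `avx512_bf16`, `amx_bf16`, `amx_fp16`) are the Linux kernel's names as I understand them. Whether PyTorch actually uses a given instruction for a given operation is a separate question, so treat this as a hint rather than a performance prediction. Note that `f16c` only covers float16/float32 conversion, not float16 arithmetic.

```python
from pathlib import Path

# Linux /proc/cpuinfo flag names of interest (assumption: x86 CPU, Linux kernel naming).
REDUCED_PRECISION_FLAGS = {
    "f16c": "float16 <-> float32 conversion only (no float16 arithmetic)",
    "avx512_fp16": "AVX-512 float16 arithmetic",
    "avx512_bf16": "AVX-512 bfloat16 dot products and conversion",
    "amx_bf16": "AMX bfloat16 tile operations",
    "amx_fp16": "AMX float16 tile operations",
}


def parse_cpu_flags(cpuinfo_text: str) -> set[str]:
    """Return the set of flags from the first 'flags' line of /proc/cpuinfo text."""
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            return set(value.split())
    return set()


def reduced_precision_report(flags: set[str]) -> dict[str, bool]:
    """Map each flag of interest to whether the CPU advertises it."""
    return {name: name in flags for name in REDUCED_PRECISION_FLAGS}


if __name__ == "__main__":
    cpuinfo = Path("/proc/cpuinfo")
    if cpuinfo.exists():
        report = reduced_precision_report(parse_cpu_flags(cpuinfo.read_text()))
        for name, present in report.items():
            print(f"{name:<12} {'yes' if present else 'no':<4} {REDUCED_PRECISION_FLAGS[name]}")
    else:
        print("/proc/cpuinfo not found (not Linux?)")
```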

cw235 commented 1 day ago

I think your approach to benchmarking performance using different torch_dtypes is thorough and provides valuable insights into the performance differences on a CPU. The significant slowdown with torch.float16 across different models reinforces the need for warnings when this dtype is used on CPUs.

You might consider sharing your findings on relevant forums or with the maintainers of the Transformers library, as this could help other users avoid similar pitfalls. Adding an informative warning message when torch.float16 is used on a CPU is a practical and user-friendly solution.

Overall, the methodology and detailed performance metrics make a strong case for more awareness around dtype performance implications.

blincoln-bf commented 1 day ago

Hi @Kevin0624.

The system where I ran that benchmark script has an AMD Ryzen 9 7950X3D (16 cores) and 128 GiB of RAM, in addition to an RTX 4090.

Regardless of the reason, it seems like a good idea to warn the user about very inefficient torch_dtype choices, either in the documentation or at runtime when the dtype is a poor fit for the device. This matters most for float16, since the .half() approach is suggested in so many Transformers tutorials. As I mentioned, that script only demonstrates text generation, where the difference is about an order of magnitude. For more complex work, the difference can be truly unbelievable. The most dramatic example I've found so far is with gpt-j-6b, where an operation in my project takes 25+ hours on the CPU with torch_dtype = float16, versus three minutes if the model is loaded with torch_dtype = bfloat16. For gpt-neox-20b, it's 10+ hours versus 20 minutes. For Phi-3-medium-128k-instruct, it's four hours versus two and a half minutes. This is with no other changes to the code, only that one flag during the model load operation.
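Since the only change needed was the `torch_dtype` passed at load time, a script can pick it per device. `choose_dtype_name` below is a hypothetical sketch that works with dtype *names* so it runs without torch. The policy is an assumption drawn from the measurements in this thread: float16 was slow on the Ryzen and Intel CPUs, and bfloat16 was only fast on some of them. So bfloat16 is used only when the caller says it runs well, and float32 is the conservative fallback.

```python
def choose_dtype_name(requested: str, device_type: str, cpu_bf16_is_fast: bool = False) -> str:
    """Hypothetical helper: avoid float16 on CPU devices.

    requested: dtype name such as "float16", "bfloat16" or "float32".
    device_type: "cpu", "cuda", "mps", ...
    cpu_bf16_is_fast: caller's judgement (e.g. from benchmarking) that bfloat16 runs well on this CPU.
    """
    if device_type == "cpu" and requested == "float16":
        return "bfloat16" if cpu_bf16_is_fast else "float32"
    return requested


# Usage with a real model (requires torch and transformers; not run here):
#   name = choose_dtype_name("float16", "cpu", cpu_bf16_is_fast=True)
#   model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=getattr(torch, name))
```

On hardware where CPU float16 was fast in the benchmarks, such as the M1 results below, this substitution may be unnecessary.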

However, you are correct in that it's very device-dependent. I ran the same benchmark on an M1 MacBook (using Qwen2-0.5B-Instruct, because that MacBook only has 16 GiB of RAM), and got very different results:

Finished processing prompt 0 at 2024-11-12T11:45:44 - 2.875962 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T11:45:46 - 1.427849 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T11:45:49 - 3.508065 seconds elapsed in total for this prompt.
Finished test with torch_dtype = None at 2024-11-12T11:45:49 - 8.919457 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-12T11:45:52 - 2.56628 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T11:45:54 - 1.469897 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T11:45:57 - 3.528596 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float32 at 2024-11-12T11:45:57 - 7.982476 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-12T11:45:59 - 0.924022 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T11:45:59 - 0.484926 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T11:46:00 - 0.912076 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float16 at 2024-11-12T11:46:00 - 2.669201 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-12T11:46:01 - 0.842285 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T11:46:02 - 0.600835 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T11:46:03 - 1.044211 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.bfloat16 at 2024-11-12T11:46:03 - 2.711551 seconds elapsed for the entire test cycle using this dtype.

I'll run that script with Qwen2-0.5B-Instruct on a Windows laptop with an Intel processor and 16 GiB of RAM a little later today for comparison.

blincoln-bf commented 1 day ago

Script output for Qwen2-0.5B-Instruct on a Windows laptop with an Intel Core i5-10310U and 16 GiB of RAM. On this device, bfloat16 performance is almost as slow as float16, so I'm even more glad I went with AMD and Linux for my ML research system :).

Finished processing prompt 0 at 2024-11-12T12:45:03 - 9.454424 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T12:45:07 - 4.606776 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T12:45:18 - 10.964283 seconds elapsed in total for this prompt.
Finished test with torch_dtype = None at 2024-11-12T12:45:18 - 29.24945 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-12T12:45:29 - 8.341307 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T12:45:33 - 4.570537 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T12:45:44 - 10.71891 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float32 at 2024-11-12T12:45:44 - 25.648423 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-12T12:46:45 - 59.440162 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T12:47:37 - 52.715362 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T12:48:47 - 69.450205 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.float16 at 2024-11-12T12:48:47 - 182.775335 seconds elapsed for the entire test cycle using this dtype.
Finished processing prompt 0 at 2024-11-12T12:49:44 - 56.260634 seconds elapsed in total for this prompt.
Finished processing prompt 1 at 2024-11-12T12:50:29 - 45.242265 seconds elapsed in total for this prompt.
Finished processing prompt 2 at 2024-11-12T12:51:29 - 59.411733 seconds elapsed in total for this prompt.
Finished test with torch_dtype = torch.bfloat16 at 2024-11-12T12:51:29 - 161.52774 seconds elapsed for the entire test cycle using this dtype.
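Putting the three machines side by side makes the device dependence clearer. The snippet below divides each whole-cycle total by that machine's float32 total, using figures copied from the logs in this thread. It is only a rough comparison. The Ryzen run used a different model (Phi-3-mini-128k) from the M1 and Intel runs (Qwen2-0.5B-Instruct), and the totals include model loading.

```python
# Whole-cycle seconds copied from the logs in this thread (totals include model loading).
# Note: the Ryzen run used Phi-3-mini-128k; the M1 and Intel runs used Qwen2-0.5B-Instruct.
results = {
    "Ryzen 9 7950X3D (Phi-3-mini-128k)": {
        "float32": 48.931164, "float16": 257.981995, "bfloat16": 21.009795,
    },
    "Apple M1 (Qwen2-0.5B-Instruct)": {
        "float32": 7.982476, "float16": 2.669201, "bfloat16": 2.711551,
    },
    "Core i5-10310U (Qwen2-0.5B-Instruct)": {
        "float32": 25.648423, "float16": 182.775335, "bfloat16": 161.52774,
    },
}


def relative_to_float32(totals: dict[str, float]) -> dict[str, float]:
    """Time for each dtype divided by the float32 time (>1.0 means slower than float32)."""
    base = totals["float32"]
    return {dtype: secs / base for dtype, secs in totals.items()}


for machine, totals in results.items():
    rel = relative_to_float32(totals)
    print(f"{machine}: float16 {rel['float16']:.2f}x, bfloat16 {rel['bfloat16']:.2f}x of float32 time")
```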