huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

High cpu memory usage as bf16 model is auto loaded as fp32 #34743

Open Qubitium opened 8 hours ago

Qubitium commented 8 hours ago

System Info

Ubuntu 24.04, Transformers 4.46.2, Accelerate 1.1.1, Safetensors 0.4.5

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

Unexpected 2x CPU memory usage due to a bf16 safetensors checkpoint being loaded as float32 on device=cpu.

Manually passing torch_dtype=torch.bfloat16 avoids the issue, but this should not be necessary since both model.config and the safetensors files already specify bfloat16.

Sample reproducing code:

import torch
from transformers import AutoModelForCausalLM
import psutil

# model is stored as bf16 safetensor
model_file = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_file)

process = psutil.Process()
memory_info = process.memory_info()
print(f"RSS (Resident Set Size): {memory_info.rss / 1024 / 1024:.2f} MB")
print(f"VMS (Virtual Memory Size): {memory_info.vms / 1024 / 1024:.2f} MB")

print(f"model config dtype is {model.config.torch_dtype}")
assert model.config.torch_dtype == torch.bfloat16

p = next(model.parameters())
print(f"model first parameter dtype: {p.dtype}, device: {p.device}")
assert p.device == torch.device("cpu")
assert p.dtype == torch.bfloat16

Code output:

Traceback (most recent call last):
  File "/GPTQModel/test.py", line 20, in <module>
    assert p.dtype == torch.bfloat16
           ^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
RSS (Resident Set Size): 5189.39 MB <----- High memory usage
VMS (Virtual Memory Size): 41335.09 MB
model config dtype is torch.bfloat16
model first parameter dtype: torch.float32, device: cpu <----- Wrong dtype

Expected behavior

Modifying the above code to pass torch_dtype=torch.bfloat16 to from_pretrained gives the normal/expected memory usage.
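The only change is the from_pretrained call:

model = AutoModelForCausalLM.from_pretrained(model_file, torch_dtype=torch.bfloat16)

Output with this change: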

RSS (Resident Set Size): 603.85 MB <----- Expected memory usage
VMS (Virtual Memory Size): 40607.80 MB
model config dtype is torch.bfloat16
model first parameter dtype: torch.bfloat16, device: cpu

There are two related issues here:

  1. bfloat16 weights are wrongly inflated to float32, causing very high memory usage
  2. safetensors weights should be lazy-loaded, so only around 600MB of weights should end up in memory (see the sketch below)

Manually passing torch_dtype=torch.bfloat16 to from_pretrained fixes this issue.
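On point 2, a minimal sketch of what I mean by lazy loading, using the safetensors API directly (model_path here is a hypothetical local path to the downloaded checkpoint, not something from_pretrained exposes):

from safetensors import safe_open

model_path = "model.safetensors"  # hypothetical local path to the bf16 checkpoint

# safe_open memory-maps the file; tensors are only materialized when requested,
# and they come back in their stored dtype (bf16 here), not float32.
with safe_open(model_path, framework="pt", device="cpu") as f:
    first_name = next(iter(f.keys()))
    t = f.get_tensor(first_name)
    print(first_name, t.dtype)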

LysandreJik commented 6 hours ago

Hey @Qubitium, the model was indeed serialized as bf16, but here you're not specifying in which dtype you would like to load it.

We follow torch's default loading mechanism, which is to automatically load it in the default torch.dtype (here, fp32) so as to be compatible with all hardware and setups.
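You can confirm that default yourself (a quick sanity check, assuming nothing in your environment has changed torch's default dtype):

import torch

print(torch.get_default_dtype())  # torch.float32 on a stock setup; changeable via torch.set_default_dtype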

In order to update the dtype in which it should be loaded, please change this line:

- model = AutoModelForCausalLM.from_pretrained(model_file)
+ model = AutoModelForCausalLM.from_pretrained(model_file, torch_dtype=torch.bfloat16)

You can also use 'auto' so as to respect the dtype of the weights themselves:

- model = AutoModelForCausalLM.from_pretrained(model_file)
+ model = AutoModelForCausalLM.from_pretrained(model_file, torch_dtype='auto')
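Putting it together, a minimal sketch of the 'auto' variant of your reproduction script; it should report bfloat16 parameters and memory usage comparable to the explicit torch_dtype=torch.bfloat16 case:

import torch
from transformers import AutoModelForCausalLM

# 'auto' resolves the dtype from config.json (falling back to the dtype of the
# first floating-point weight in the checkpoint), so the bf16 weights stay bf16.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype="auto"
)

p = next(model.parameters())
print(p.dtype, p.device)  # expected: torch.bfloat16 cpu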

You can read more about this in the from_pretrained documentation which I am pasting below:

[screenshot of the torch_dtype section of the from_pretrained docstring]

Qubitium commented 4 hours ago

@LysandreJik It's 2024 and I would like to propose that the float32 default be changed. Please read the following with a light heart.

Reasons:

Overall, accept the config.json dtype as the source of truth unless there is an explicit override, or unless that dtype is genuinely incompatible with the GPU/CPU, i.e. the device does not physically support the model-specified dtype. A rough sketch of what I am proposing follows the docstring quote below.

torch_dtype (`str` or `torch.dtype`, *optional*):
     Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
     are:

     1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
      `dtype`, ignoring the model's `config.torch_dtype` if one exists. If not specified
      - the model will get loaded in `torch.float` (fp32).

      2. `"auto"` - A `torch_dtype` entry in the `config.json` file of the model will be
      attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in
      the checkpoint that's of a floating point type and use that as `dtype`. This will load the model
      using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
      the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.

      3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.

      <Tip>

      For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
      reach out to the authors and ask them to add this information to the model's card and to insert the
      `torch_dtype` entry in `config.json` on the hub.

      </Tip>
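For illustration only, a rough sketch of the dtype resolution I am proposing (resolve_load_dtype is a hypothetical helper, not an existing transformers API):

import torch

def resolve_load_dtype(config, override=None, device="cpu"):
    # Hypothetical helper: an explicit override wins; otherwise trust
    # config.torch_dtype; only fall back to float32 when the config carries
    # no dtype or the target device cannot handle the configured one.
    if override is not None:
        return override
    cfg_dtype = getattr(config, "torch_dtype", None)
    if cfg_dtype is None:
        return torch.float32
    if (
        cfg_dtype == torch.bfloat16
        and device == "cuda"
        and not torch.cuda.is_bf16_supported()
    ):
        return torch.float32
    return cfg_dtype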