huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Nested from_pretrained() gives warnings loading weights - "copying from a non-meta parameter" #31544

Open jamt9000 opened 1 month ago

jamt9000 commented 1 month ago

I am seeing the PyTorch warning "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model, which is a no-op" when loading the CLIP vision tower of a LLaVA model (in the LLaVA / LLaVA-NeXT repos, not the HF versions).

It seems to happen because LlavaLlamaForCausalLM.from_pretrained() calls CLIPVisionModel.from_pretrained() to create the vision tower, so the context managers that disable weight initialisation (due to _fast_init and/or low_cpu_mem_usage being True?) are still active when the second from_pretrained() tries to load weights from its checkpoint.

https://github.com/huggingface/transformers/blob/74a207404e8d4524d1fdc4aa23789694f9eef347/src/transformers/modeling_utils.py#L3704-L3706
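
For reference, the problematic shape of the call is roughly the following (a minimal sketch with illustrative names, not the actual LLaVA code):

from transformers import CLIPVisionModel, LlamaForCausalLM

class ToyLlavaLlamaForCausalLM(LlamaForCausalLM):
    # Illustrative stand-in for LLaVA's LlavaLlamaForCausalLM: the vision
    # tower is created with a nested from_pretrained() inside __init__.
    def __init__(self, config):
        super().__init__(config)
        # When the outer ToyLlavaLlamaForCausalLM.from_pretrained() runs,
        # this inner call executes while the outer call's init-disabling /
        # meta-device context managers are still active.
        self.vision_tower = CLIPVisionModel.from_pretrained(
            "openai/clip-vit-large-patch14-336"
        )

# model = ToyLlavaLlamaForCausalLM.from_pretrained(<llama checkpoint>, low_cpu_mem_usage=True)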

This means that the _load_pretrained_model() call that loads the CLIP weights will break, because it doesn't handle the meta tensors; for example, here it will incorrectly use id_tensor_storage() on meta tensors:

https://github.com/huggingface/transformers/blob/74a207404e8d4524d1fdc4aa23789694f9eef347/src/transformers/modeling_utils.py#L4018-L4028
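
As a quick illustration of why that kind of check is meaningless on the meta device (a sketch, independent of the transformers internals linked above):

import torch

# Parameters created under a meta-device / no-init context carry no data,
# only shape and dtype, so storage-identity bookkeeping over them cannot
# tell two distinct parameters apart.
a = torch.nn.Parameter(torch.empty(3, 3, device="meta"))
b = torch.nn.Parameter(torch.empty(3, 3, device="meta"))
print(a.is_meta, b.is_meta)  # True True
print(a.device, b.device)    # meta meta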

This seems to be related to the issue reported here: https://github.com/haotian-liu/LLaVA/issues/1122

System Info

- `transformers` version: 4.41.2
- Platform: Linux-5.15.0-1043-aws-x86_64-with-glibc2.31
- Python version: 3.11.7
- Huggingface_hub version: 0.23.3
- Safetensors version: 0.4.3
- Accelerate version: 0.31.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes, A100
- Using distributed or parallel set-up in script?: no

Who can help?

@SunMarc

Reproduction

The warnings can be seen when loading LLaVA-NeXT as in https://github.com/LLaVA-VL/LLaVA-NeXT/blob/inference/docs/LLaVA-NeXT.md

from llava.model.builder import load_pretrained_model
pretrained = "lmms-lab/llama3-llava-next-8b"
model_name = "llava_llama3"
device = "cuda"
device_map = "auto"
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map)

The warning can also be seen when using init_empty_weights() explicitly, like this (although here it is more understandable why it might not work, compared to the low_cpu_mem_usage boolean, which may get force-set to True in modeling_utils.py):

from transformers import CLIPVisionModel
from accelerate import init_empty_weights

with init_empty_weights():
    clip = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
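
For what it's worth, after running the snippet above, a quick way to see how much of the model actually got materialised is the check below (a sketch; the same check can be pointed at the vision tower of the LLaVA model from the first reproduction):

# Count parameters that are still on the meta device, i.e. that were never
# actually filled from the checkpoint.
still_meta = [name for name, p in clip.named_parameters() if p.is_meta]
print(len(still_meta), "parameters still on the meta device")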

Expected behavior

It's unclear whether the loaded model is correct or not, but I would expect the weights to be loaded correctly, without warnings, when using CLIPVisionModel.from_pretrained(). If it is not intended to call CLIPVisionModel.from_pretrained() within the __init__ of another PreTrainedModel, then it would be good to have an error explaining that, or some safety checks for meta tensors in _load_pretrained_model.
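
Something along these lines is what I have in mind (a rough sketch of a hypothetical helper, not an existing transformers function):

import warnings
import torch

def warn_if_meta_params(model: torch.nn.Module) -> None:
    # Hypothetical post-load check (not part of transformers): warn loudly if
    # from_pretrained() finished but some parameters were never materialised,
    # e.g. because it ran inside another from_pretrained() or an
    # init_empty_weights() context.
    leftover = [name for name, p in model.named_parameters() if p.is_meta]
    if leftover:
        warnings.warn(
            f"{len(leftover)} parameters are still on the meta device and were "
            "not loaded from the checkpoint, e.g. " + ", ".join(leftover[:3])
        )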

amyeroberts commented 3 weeks ago

Hi @jamt9000, thanks for raising this issue!

Indeed, I think you're right that the nested from_pretrained is causing this. from_pretrained is meant to be called on the parent class, which should inherit from PreTrainedModel and bundle all the other modules together.

You'll see that within the transformers repo, we create submodules using from_config, e.g. here for llava. This ensures there's a single call to from_pretrained, which handles weight loading for the entire model. This is important for various reasons. As you highlighted, there's certain logic, such as the context managers used to load and instantiate weights, that is assumed to be invoked only once. from_pretrained also takes other arguments such as device_map, which allocates weights to certain devices and, if "auto" is specified, tries to put as much of the model as possible on the available accelerator. If multiple modules call from_pretrained, it's not possible to correctly calculate and allocate the weights, as this requires a global view of the model.
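
For a CLIP vision tower, the from_config-style pattern looks roughly like this (a sketch only; the checkpoint name is just illustrative):

from transformers import CLIPVisionConfig, CLIPVisionModel

# The submodule is only constructed here, with random weights -- the single
# outer from_pretrained() call is what loads weights for the whole model.
vision_config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14-336")
vision_tower = CLIPVisionModel(vision_config)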

TBH, this isn't an issue that we've encountered before, and a simple warning is probably the best fix. If you or anyone else in the community would like to open a PR to add one, we'd be happy to review!

github-actions[bot] commented 12 hours ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.