huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Model load when dtypes match is broken #32089

Closed zucchini-nlp closed 1 month ago

zucchini-nlp commented 1 month ago

System Info

The fast-init PR (#31771) seems to have broken Chameleon loading. When I load the model on CPU with the same dtype as the checkpoint weights (bf16), inference fails with a dtype mismatch. It doesn't fail if I load on GPU with device_map="cuda", though.

Weights in the VQ module are now in fp32, while the LM module is in bf16. I can still make it work by not casting pixel_values to bf16, but that is not expected behavior and causes inconsistencies: if I load with fp16, I would also have to cast the inputs to fp16.
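One way to confirm a split like this is to group parameter dtypes by submodule prefix. The helper below is a minimal sketch, not part of transformers: `dtype_summary`, its `depth` argument, and the parameter names in the demo are hypothetical and chosen for illustration only. It relies solely on each parameter exposing a `.dtype` attribute, so it should also accept the output of `model.named_parameters()`.

```python
from collections import defaultdict
from types import SimpleNamespace


def dtype_summary(named_params, depth=2):
    """Map each submodule prefix (first `depth` name parts) to the dtypes found under it."""
    summary = defaultdict(set)
    for name, param in named_params:
        prefix = ".".join(name.split(".")[:depth])
        summary[prefix].add(str(param.dtype))
    return {prefix: sorted(dtypes) for prefix, dtypes in summary.items()}


# Stand-in parameters with made-up names; with a real model you would pass
# model.named_parameters() instead.
fake_params = [
    ("model.vqmodel.conv.weight", SimpleNamespace(dtype="torch.float32")),
    ("model.layers.0.weight", SimpleNamespace(dtype="torch.bfloat16")),
    ("model.layers.1.weight", SimpleNamespace(dtype="torch.bfloat16")),
]
print(dtype_summary(fake_params))
```

A prefix that maps to more than one dtype, or two prefixes that disagree with the requested torch_dtype, would point to the module that was not cast.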

Reproduction

from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
import torch
import requests
from PIL import Image

DTYPE = torch.bfloat16
model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", torch_dtype=DTYPE).to("cuda:0")
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)
inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, DTYPE) # dropping DTYPE here works around the error for bf16, but the cast is still needed for fp16

generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.2)

Expected behavior

Composite models like Chameleon should not break when loaded with the same dtype as their weights.

@muellerzr @ArthurZucker I didn't dive deep yet, guess you will be faster in spotting the root cause

zucchini-nlp commented 1 month ago

My bad, found a flag that disables fast loading (_supports_param_buffer_assignment)

Anyway, I would like to understand why fast init fails to keep the same dtype for the vision module, and whether we will be able to support this kind of model, so I'm leaving the issue open :)

muellerzr commented 1 month ago

@zucchini-nlp does that model have the flag? If not, could you make a PR to do so as a quick fix?

Indeed, the real issue is that chunk being initialized in float32 despite the explicit dtype.

zucchini-nlp commented 1 month ago

Yes, made a PR (https://github.com/huggingface/transformers/pull/32091) to fix. Also I found other composite models like LLaVa aren't broken, so I have no idea what was wrong with Chameleon

zucchini-nlp commented 1 month ago

I found why it was defaulting to fp32 in the vision model and bf16 in the LM. The original weights were loaded and converted in that precision, so fast init was loading them in the same dtype as the checkpoint weights. Closing as resolved!