huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

NotImplementedError: Cannot copy out of meta tensor; no data! because some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection #7506

Open shabri-arrahim opened 5 months ago

shabri-arrahim commented 5 months ago

Describe the bug

I tried to load a .safetensors file and save it in the diffusers format, and got this warning: Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection: ['text_model.embeddings.position_ids']

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler

model_params = {
    "pretrained_model_link_or_path": "/workspace/playground-v2.5-1024px-aesthetic.fp16.safetensors",
    "torch_dtype": torch.float16,
}

pipe = StableDiffusionXLPipeline.from_single_file(**model_params)

pipe.save_pretrained(
    save_directory="/workspace/playground-v2.5.fp16",
    safe_serialization=True,
    variant="fp16",
    push_to_hub=False,
)

When I try to load the saved model, I get the error NotImplementedError: Cannot copy out of meta tensor; no data!

pipe = StableDiffusionXLPipeline.from_pretrained("/workspace/playground-v2.5.fp16", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")

Reproduction

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler

model_params = {
    "pretrained_model_link_or_path": "/workspace/playground-v2.5-1024px-aesthetic.fp16.safetensors",
    "torch_dtype": torch.float16,
}

pipe = StableDiffusionXLPipeline.from_single_file(**model_params)

pipe.save_pretrained(
    save_directory="/workspace/playground-v2.5.fp16",
    safe_serialization=True,
    variant="fp16",
    push_to_hub=False,
)

pipe = StableDiffusionXLPipeline.from_pretrained("/workspace/playground-v2.5.fp16", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
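Before calling pipe.to("cuda"), one way to see which weights are affected is to scan each component for tensors on the meta device (tensors that have a shape and dtype but no data). This is a minimal sketch assuming PyTorch 2.x; find_meta_tensors is a hypothetical helper, not a diffusers API, and the commented usage assumes the pipeline's components dict.

```python
from torch import nn


def find_meta_tensors(module: nn.Module) -> list[str]:
    """Return names of parameters and buffers that live on the meta device.

    Meta tensors carry shape/dtype but no data, so copying them
    (e.g. via module.to("cuda")) raises NotImplementedError.
    """
    names = [name for name, p in module.named_parameters() if p.is_meta]
    names += [name for name, b in module.named_buffers() if b.is_meta]
    return names


# Hypothetical usage with the reloaded pipeline (not executed here):
# for comp_name, comp in pipe.components.items():
#     if isinstance(comp, nn.Module):
#         missing = find_meta_tensors(comp)
#         if missing:
#             print(comp_name, missing[:5])

if __name__ == "__main__":
    # A layer created directly on the meta device reproduces the error in isolation.
    layer = nn.Linear(4, 4, device="meta")
    print(find_meta_tensors(layer))  # ['weight', 'bias']
    try:
        layer.to("cpu")
    except NotImplementedError as err:
        print("NotImplementedError:", err)
```

If a component such as text_encoder_2 reports names here, moving the pipeline to a device will fail on exactly those tensors.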

Logs

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[11], line 2
      1 pipe = StableDiffusionXLPipeline.from_pretrained("/workspace/playground-v2.5.fp16", torch_dtype=torch.float16, variant="fp16")
----> 2 pipe = pipe.to("cuda")
      3 pipe.safety_checker = None
      4 # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
      5 
      6 #3

File /usr/local/lib/python3.10/dist-packages/diffusers/pipelines/pipeline_utils.py:418, in DiffusionPipeline.to(self, *args, **kwargs)
    414     logger.warning(
    415         f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
    416     )
    417 else:
--> 418     module.to(device, dtype)
    420 if (
    421     module.dtype == torch.float16
    422     and str(device) in ["cpu"]
    423     and not silence_dtype_warnings
    424     and not is_offloaded
    425 ):
    426     logger.warning(
    427         "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
    428         " is not recommended to move them to `cpu` as running them will fail. Please make"
   (...)
    431         " `torch_dtype=torch.float16` argument, or use another device for inference."
    432     )

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:2576, in PreTrainedModel.to(self, *args, **kwargs)
   2571     if dtype_present_in_args:
   2572         raise ValueError(
   2573             "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
   2574             " `dtype` by passing the correct `torch_dtype` argument."
   2575         )
-> 2576 return super().to(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1145, in Module.to(self, *args, **kwargs)
   1141         return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                     non_blocking, memory_format=convert_to_format)
   1143     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
-> 1145 return self._apply(convert)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

    [... skipping similar frames: Module._apply at line 797 (3 times)]

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:797, in Module._apply(self, fn)
    795 def _apply(self, fn):
    796     for module in self.children():
--> 797         module._apply(fn)
    799     def compute_should_use_set_data(tensor, tensor_applied):
    800         if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    801             # If the new tensor has compatible tensor type as the existing tensor,
    802             # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    807             # global flag to let the user control whether they want the future
    808             # behavior of overwriting the existing tensor or not.

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:820, in Module._apply(self, fn)
    816 # Tensors stored in modules are graph leaves, and we don't want to
    817 # track autograd history of `param_applied`, so we have to use
    818 # `with torch.no_grad():`
    819 with torch.no_grad():
--> 820     param_applied = fn(param)
    821 should_use_set_data = compute_should_use_set_data(param, param_applied)
    822 if should_use_set_data:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1143, in Module.to.<locals>.convert(t)
   1140 if convert_to_format is not None and t.dim() in (4, 5):
   1141     return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
   1142                 non_blocking, memory_format=convert_to_format)
-> 1143 return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

NotImplementedError: Cannot copy out of meta tensor; no data!

Who can help?

@yiyixuxu @sayakpaul @DN6

sayakpaul commented 5 months ago

How was pipe initialized?

shabri-arrahim commented 5 months ago

Quoting @sayakpaul: How was pipe initialized?

pipe = StableDiffusionXLPipeline.from_single_file(**model_params) @sayakpaul

sayakpaul commented 5 months ago

Where does "playground-v2.5-1024px-aesthetic.fp16.safetensors" come from?

shabri-arrahim commented 5 months ago

Quoting @sayakpaul: Where does "playground-v2.5-1024px-aesthetic.fp16.safetensors" come from?

I downloaded it from this link: https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/resolve/main/playground-v2.5-1024px-aesthetic.fp16.safetensors

sayakpaul commented 5 months ago

Cc: @DN6 could you take a look?

shabri-arrahim commented 5 months ago

I still don't know why this happens, but at least I know that text_encoder_2 fails to load. If someone is experiencing a similar issue, this workaround might help (although I must admit it's not the most efficient approach 🙃).

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers import DPMSolverMultistepScheduler
from transformers import CLIPTextModelWithProjection

from safetensors.torch import load_file as safe_load
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_open_clip_checkpoint

checkpoint = safe_load("/workspace/playground-v2.5-1024px-aesthetic.fp16.safetensors", device="cpu")

# Some checkpoints nest their weights under one or more "state_dict" keys;
# unwrap them if present (a no-op for a flat checkpoint like this one).
while "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]

config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
    checkpoint,
    config_name,
    prefix="conditioner.embedders.1.model.",
    has_projection=True,
    local_files_only=False,
    **config_kwargs,
)

model_params = {
    "pretrained_model_link_or_path": "/workspace/playground-v2.5-1024px-aesthetic.fp16.safetensors",
    "torch_dtype": torch.float16,
    "use_safetensors": True,
    "add_watermarker": False,
}

pipe = StableDiffusionXLPipeline.from_single_file(**model_params)
pipe.text_encoder_2 = text_encoder_2

pipe.save_pretrained(
    save_directory="/workspace/playground-v2.5.fp16",
    safe_serialization=True,
    variant="fp16",
    push_to_hub=False,
)

sayakpaul commented 5 months ago

Cc: @DN6 could you take a look?

DN6 commented 5 months ago

Hmm, strange. Some tensors are not being saved in the OpenCLIP model when calling save_pretrained. Taking a look.
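To confirm which tensors were dropped by save_pretrained, one can compare a component's state_dict with the tensor names actually written to disk. This is a minimal sketch; keys_missing_from_checkpoint and read_safetensors_keys are hypothetical helpers, and the file path in the usage comment assumes the usual text_encoder_2/model.fp16.safetensors layout for variant="fp16", which may differ (for example, if the weights are sharded).

```python
from torch import nn


def keys_missing_from_checkpoint(module: nn.Module, saved_keys) -> list[str]:
    """Return state_dict keys of `module` that are absent from `saved_keys`."""
    return sorted(set(module.state_dict()) - set(saved_keys))


def read_safetensors_keys(path: str) -> list[str]:
    """List tensor names stored in a .safetensors file (requires `safetensors`)."""
    from safetensors import safe_open

    with safe_open(path, framework="pt") as f:
        return list(f.keys())


# Hypothetical usage (path assumed; adjust to your save directory):
# saved = read_safetensors_keys(
#     "/workspace/playground-v2.5.fp16/text_encoder_2/model.fp16.safetensors"
# )
# print(keys_missing_from_checkpoint(pipe.text_encoder_2, saved))
```

Any names printed are weights that will be missing, and therefore left uninitialized, when the saved pipeline is loaded again.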

DN6 commented 5 months ago

Hi @shabri-arrahim, I tracked the issue down to these lines: https://github.com/huggingface/diffusers/blob/2b04ec2ff7270d2044410378b04d85a194fa3d4a/src/diffusers/loaders/single_file_utils.py#L1238-L1240

When accelerate is installed and the pipeline is saved to safetensors, we attempt to save those weights as shared tensors (which the safetensors format currently doesn't support), so they are omitted and end up as meta tensors. That leads to the error when you load the model and move it to a device.
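The explanation above hinges on several weights sharing one underlying storage. Below is a minimal sketch, assuming PyTorch 2.x, for spotting such groups in a state dict and breaking them up with .clone() before saving. find_shared_groups and unshare are hypothetical helpers, and slicing a fused q/k/v weight in the demo is only an illustrative way shared tensors can arise, not a claim about what the linked lines do.

```python
from collections import defaultdict

import torch


def find_shared_groups(state_dict: dict) -> list[list[str]]:
    """Group state_dict keys whose tensors are backed by the same storage.

    Views (slices, transposes) of one tensor share a storage, so they land
    in the same group even when they cover different elements.
    """
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        key = (tensor.device, tensor.untyped_storage().data_ptr())
        groups[key].append(name)
    return [sorted(names) for names in groups.values() if len(names) > 1]


def unshare(state_dict: dict) -> dict:
    """Return a copy where every tensor in a shared group owns its storage.

    Only shared tensors are cloned, so memory grows by their size alone.
    """
    shared = {name for group in find_shared_groups(state_dict) for name in group}
    return {
        name: t.clone() if name in shared else t
        for name, t in state_dict.items()
    }


if __name__ == "__main__":
    # Illustrative fused projection split into q/k/v by slicing (all views).
    fused = torch.randn(6, 4)
    sd = {"q.weight": fused[:2], "k.weight": fused[2:4], "v.weight": fused[4:]}
    print(find_shared_groups(sd))           # [['k.weight', 'q.weight', 'v.weight']]
    print(find_shared_groups(unshare(sd)))  # []
```

Cloning gives each weight an independent buffer, so a serializer no longer sees them as one shared tensor and has no reason to drop any of them.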

I'll include a fix for this in #7496.

github-actions[bot] commented 4 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

crapthings commented 4 months ago

Hello, I used EveryDream2 to train a model, but I can't load the trained model.

Before converting:

Please load the component before passing it in as an argument to from_single_file.

text_encoder = CLIPTextModel.from_pretrained('...')
pipe = StableDiffusionControlNetPipeline.from_single_file(, text_encoder=text_encoder)

After converting:

[screenshot]

Converting doesn't work either:

NotImplementedError: Cannot copy out of meta tensor; no data!

DN6 commented 4 months ago

@crapthings could you create a separate issue please with a reproducible code example (no screenshots). Not sure if your problem is related.

crapthings commented 4 months ago

Quoting @DN6: @crapthings could you create a separate issue please with a reproducible code example (no screenshots). Not sure if your problem is related.

Reverting to 0.27.2 works.

github-actions[bot] commented 5 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.