huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

[0.27.2]: from_single_file: 'AutoencoderKL' object has no attribute '__name__'. Did you mean: '__ne__'? #8320

Open crapthings opened 5 months ago

crapthings commented 5 months ago

Describe the bug

I fine-tuned a DreamBooth model with ED2 (EveryDream2), but I can't load it with 0.27.2.

The model works in a1111.

Reproduction

pipe.from_single_file

https://huggingface.co/crapthings/diffusers-issue/resolve/main/base-ed2.safetensors

Logs

No response

System Info

pip install diffusers==0.27.2

Who can help?

No response

tolgacangoz commented 5 months ago

Why don't you try with the latest version (0.28.0)?

knoopx commented 5 months ago

@tolgacangoz I tried; now something else is broken xD

DN6 commented 5 months ago

@crapthings Can you install diffusers from main and try? Just merged a change to address this.

crapthings commented 5 months ago

With this model:

https://huggingface.co/crapthings/diffusers-issue/resolve/main/base-ed2.safetensors

I upgraded to main, but it produces a new error:

You have disabled the safety checker for <class 'diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Traceback (most recent call last):
  File "/home/zznet/workspace/ai-pipe/prod-ai-pipe-replace-model/runpod_app.py", line 14, in <module>
    from pipe import text2img, getText2imgPipe, set_sampler, compel
  File "/home/zznet/workspace/ai-pipe/prod-ai-pipe-replace-model/pipe.py", line 57, in <module>
    text2imgPipe.enable_model_cpu_offload()
  File "/home/zznet/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 1063, in enable_model_cpu_offload
    self.to("cpu", silence_dtype_warnings=True)
  File "/home/zznet/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_utils.py", line 431, in to
    module.to(device, dtype)
  File "/home/zznet/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
  File "/home/zznet/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/zznet/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/zznet/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/home/zznet/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 833, in _apply
    param_applied = fn(param)
  File "/home/zznet/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1158, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!

DN6 commented 5 months ago

Hi @crapthings, it appears that the VAE keys in your checkpoint differ from what is usually expected in a single-file checkpoint. Compare it with the other checkpoint you've shared: https://huggingface.co/crapthings/diffusers-issue/blob/main/base.safetensors

You will notice that the VAE keys in your base.safetensors checkpoint look like

first_stage_model.decoder.mid.attn_1.q.bias

while the ones in the ED2 checkpoint look like

first_stage_model.decoder.mid.attn_1.to_q.weight

We expect the first type when loading in diffusers. Let me see if there's a way to support this, but it seems like the issue is with how ED2 is saving the checkpoint.
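To check which naming a single-file checkpoint uses before passing it to `from_single_file`, one option is to read only the safetensors header instead of loading the weights. The sketch below uses only the standard library and assumes the safetensors layout (an 8-byte little-endian header length followed by a JSON header keyed by tensor name). The helper names `read_safetensors_keys` and `vae_mid_attention_style` are hypothetical and not part of diffusers; the `first_stage_model.decoder.mid.attn_1.` prefix is taken from the keys quoted above.

```python
import json
import struct


def read_safetensors_keys(path):
    """Return the tensor names stored in a .safetensors file without loading weights.

    Assumes the safetensors layout: an unsigned 64-bit little-endian header size,
    followed by that many bytes of UTF-8 JSON mapping tensor names to metadata.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional free-form entry, not a tensor name.
    return sorted(k for k in header if k != "__metadata__")


def vae_mid_attention_style(keys, prefix="first_stage_model.decoder.mid.attn_1."):
    """Classify the VAE mid-block attention naming found under `prefix`.

    Returns "ldm" for q/k/v/proj_out names (the form the single-file loader expects),
    "diffusers" for to_q/to_k/to_v/to_out names (the form seen in the ED2 checkpoint),
    or "unknown" if neither pattern appears.
    """
    # Take the first path component after the prefix, e.g. "to_q" from "to_q.weight".
    sub_names = {k[len(prefix):].split(".")[0] for k in keys if k.startswith(prefix)}
    if sub_names & {"to_q", "to_k", "to_v", "to_out"}:
        return "diffusers"
    if sub_names & {"q", "k", "v", "proj_out"}:
        return "ldm"
    return "unknown"
```

Going by the keys quoted above, `vae_mid_attention_style(read_safetensors_keys("base-ed2.safetensors"))` should report `"diffusers"` for the ED2 file and `"ldm"` for `base.safetensors`; this has not been run against the actual files.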

victorchall commented 4 months ago

This is just the SD1.5 checkpoint:

https://huggingface.co/panopstor/EveryDream

from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("panopstor/EveryDream", subfolder="vae")

print(vae)
state_dict = vae.state_dict()
for k in state_dict:
    print(k)

Output:

...
decoder.mid_block.attentions.0.group_norm.weight
decoder.mid_block.attentions.0.group_norm.bias
decoder.mid_block.attentions.0.to_q.weight
decoder.mid_block.attentions.0.to_q.bias
decoder.mid_block.attentions.0.to_k.weight
decoder.mid_block.attentions.0.to_k.bias
decoder.mid_block.attentions.0.to_v.weight
decoder.mid_block.attentions.0.to_v.bias
decoder.mid_block.attentions.0.to_out.0.weight
decoder.mid_block.attentions.0.to_out.0.bias
decoder.mid_block.resnets.0.norm1.weight
...

The keys were committed this way by Hugging Face personnel.

Commit:

https://huggingface.co/panopstor/EveryDream/commit/0feeaee608654bc71a572dfefa9b83f3b74b204d

The conversion scripts between the LDM and Hugging Face formats have been shrugged off as "not supported", even though people constantly try to convert back and forth because they want the single-file safetensors format for use with non-HF applications. It appears HF now supports loading the single "LDM"-style safetensors format, I guess? But when loading that way, diffusers becomes particular about the key names. The model loads and runs fine if it stays in the HF folder/split format, i.e. directly from the Hugging Face cache.
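For a checkpoint like the ED2 one, where only the VAE attention sub-modules carry diffusers-style names under otherwise LDM-style prefixes, a key-level rename back to the names the loader expects might look like the sketch below. The mapping (`group_norm` to `norm`, `to_q` to `q`, `to_k` to `k`, `to_v` to `v`, `to_out.0` to `proj_out`) is an assumption based on the original LDM attention block naming versus diffusers' attention naming, and `rename_vae_attention_keys` is a hypothetical helper. It only rewrites names; it does not change tensor shapes (LDM stores these projections as 1x1 convolutions, diffusers as linear layers), which a real conversion may also need to handle.

```python
import re

# Assumed mapping from diffusers attention sub-module names to LDM names.
DIFFUSERS_TO_LDM = {
    "group_norm": "norm",
    "to_q": "q",
    "to_k": "k",
    "to_v": "v",
    "to_out.0": "proj_out",
}

# Matches e.g. "first_stage_model.decoder.mid.attn_1.to_q.weight".
VAE_ATTN_KEY = re.compile(
    r"^(first_stage_model\.(?:encoder|decoder)\.mid\.attn_1\.)"
    r"(group_norm|to_q|to_k|to_v|to_out\.0)"
    r"(\.(?:weight|bias))$"
)


def rename_vae_attention_keys(state_dict):
    """Return a new dict with VAE mid-block attention keys renamed to LDM style.

    Non-matching keys are copied unchanged and values are passed through as-is
    (no reshaping). Raises ValueError if two keys collide after renaming.
    """
    renamed = {}
    for key, value in state_dict.items():
        match = VAE_ATTN_KEY.match(key)
        if match:
            prefix, name, suffix = match.groups()
            key = prefix + DIFFUSERS_TO_LDM[name] + suffix
        if key in renamed:
            raise ValueError(f"duplicate key after renaming: {key}")
        renamed[key] = value
    return renamed
```

To apply it to a real checkpoint, one could load the file with `safetensors.torch.load_file`, pass the resulting dict through the helper, and write it back with `safetensors.torch.save_file`; whether the renamed file then loads with `from_single_file` is untested here.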

Snippet:

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained("panopstor/EveryDream", torch_dtype=torch.float16)
pipe = pipe.to('cuda')

image = pipe(prompt="a photo of an astronaut riding a horse on mars").images[0]
print(image)

Output: <PIL.Image.Image image mode=RGB size=512x512 at 0x70CAB32B5E40>

No issues... This is of course loading from the folders on that model, not a single .safetensors file.

This has been an issue since late 2022, I believe, as everyone runs away from the conversion issues.

victorchall commented 4 months ago

This issue crops up regularly with various software:

https://github.com/huggingface/diffusers/issues/7724

victorchall commented 4 months ago

If there is a truly supported way to convert from the split folder version to a single combined LDM-style file (and back), please let me know.

victorchall commented 4 months ago

You can repeat the above with "CompVis/stable-diffusion-v1-4" from https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main

from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

print(vae)
state_dict = vae.state_dict()
for k in state_dict:
    print(k)

Output:

...
encoder.down_blocks.3.resnets.1.conv2.weight
encoder.down_blocks.3.resnets.1.conv2.bias
encoder.mid_block.attentions.0.group_norm.weight
encoder.mid_block.attentions.0.group_norm.bias
encoder.mid_block.attentions.0.to_q.weight
encoder.mid_block.attentions.0.to_q.bias
encoder.mid_block.attentions.0.to_k.weight
encoder.mid_block.attentions.0.to_k.bias
encoder.mid_block.attentions.0.to_v.weight
encoder.mid_block.attentions.0.to_v.bias
encoder.mid_block.attentions.0.to_out.0.weight
encoder.mid_block.attentions.0.to_out.0.bias
encoder.mid_block.resnets.0.norm1.weight
...

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.