huggingface / diffusers


SD3ControlNetModel.from_transformer load weights error #9012

Open · FYRichie opened this issue 4 months ago

FYRichie commented 4 months ago

Describe the bug

When specifying num_layers in SD3ControlNetModel.from_transformer, the function fails with an "unexpected key(s)" error. I think this happens because SD3ControlNetModel only contains num_layers transformer blocks, while SD3Transformer2DModel contains 24, so loading the full state_dict into the smaller ModuleList fails. It could be fixed by looping over the controlnet's transformer blocks and copying the weights block by block from the SD3 transformer (see the sketch below).
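
A minimal sketch of the per-block copy I have in mind (assuming the controlnet has already been constructed with num_layers blocks and transformer is the full 24-block model; names are only illustrative):

# Copy weights block by block instead of loading the whole 24-block
# state_dict into the smaller ModuleList at once.
for i, block in enumerate(controlnet.transformer_blocks):
    block.load_state_dict(transformer.transformer_blocks[i].state_dict())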

Reproduction

from diffusers import SD3Transformer2DModel
from diffusers.models.controlnet_sd3 import SD3ControlNetModel

n_layers = 12  # any value less than the transformer's 24 blocks triggers the error

# Load the full 24-block SD3 transformer, then try to derive a smaller controlnet from it.
transformer = SD3Transformer2DModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer")
controlnet = SD3ControlNetModel.from_transformer(transformer, n_layers)  # second argument is num_layers

Logs

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/root/anaconda3/envs/geolight/lib/python3.9/site-packages/diffusers/models/controlnet_sd3.py", line 251, in from_transformer
    controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
  File "/root/anaconda3/envs/geolight/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ModuleList:
        Unexpected key(s) in state_dict: "12.norm1.linear.weight", "12.norm1.linear.bias", "12.norm1_context.linear.weight", "12.norm1_context.linear.bias", "12.attn.to_q.weight", "12.attn.to_q.bias", "12.attn.to_k.weight", "12.attn.to_k.bias", "12.attn.to_v.weight", "12.attn.to_v.bias", "12.attn.add_k_proj.weight", "12.attn.add_k_proj.bias", "12.attn.add_v_proj.weight", "12.attn.add_v_proj.bias", "12.attn.add_q_proj.weight", "12.attn.add_q_proj.bias", "12.attn.to_out.0.weight", "12.attn.to_out.0.bias", "12.attn.to_add_out.weight", "12.attn.to_add_out.bias", "12.ff.net.0.proj.weight", "12.ff.net.0.proj.bias", "12.ff.net.2.weight", "12.ff.net.2.bias", "12.ff_context.net.0.proj.weight", "12.ff_context.net.0.proj.bias", "12.ff_context.net.2.weight", "12.ff_context.net.2.bias"...

System Info

Who can help?

@haofanwang

a-r-r-o-w commented 4 months ago

Thanks for reporting! Would you be interested in opening a PR for the proposed fix?

FYRichie commented 4 months ago

Sure

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.