huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Error(s) in initializing SD3ControlNetModel by from_transformer #8723

Closed: ChenhLiwnl closed this issue 2 hours ago

ChenhLiwnl commented 1 week ago

Describe the bug

(screenshot of the error attached)

Reproduction

from diffusers.models.controlnet_sd3 import SD3ControlNetModel
from diffusers.models.transformers import SD3Transformer2DModel
transformer = SD3Transformer2DModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer")
controlnet = SD3ControlNetModel.from_transformer(transformer)

Logs

No response

System Info

Who can help?

No response

sayakpaul commented 1 week ago

Cc: @haofanwang

yiyixuxu commented 1 week ago

Yeah, I think we reversed strict=False and strict=True. Can you try this, and if it works, would you be willing to open a PR? https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/controlnet_sd3.py#L242

        config = transformer.config
        config["num_layers"] = num_layers or config.num_layers
        controlnet = cls(**config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)

            controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)

        return controlnet
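
For context, strict=False makes load_state_dict collect missing/unexpected keys instead of raising, which is what allows copying the matching weights even when the two block stacks don't line up exactly. A minimal PyTorch sketch of the mechanism (generic modules, not the diffusers classes):

    import torch.nn as nn

    src = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))  # stands in for the full transformer
    dst = nn.Sequential(nn.Linear(4, 4))                    # stands in for a shallower ControlNet

    # strict=True would raise a RuntimeError about the unexpected keys "1.weight"/"1.bias";
    # strict=False loads the overlapping keys and just reports the rest.
    result = dst.load_state_dict(src.state_dict(), strict=False)
    print(result.unexpected_keys)  # ['1.weight', '1.bias']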
ChenhLiwnl commented 1 week ago

It still doesn't work and reports the same error.

ChenhLiwnl commented 1 week ago

controlnet.transformer_blocks and transformer.transformer_blocks differ in the last block, and I'm trying to find out why; I think that is the reason. So I set the last layer of controlnet.transformer_blocks like this:

    controlnet.transformer_blocks[-1] = JointTransformerBlock(
        dim=self.inner_dim,
        num_attention_heads=num_attention_heads,
        attention_head_dim=self.inner_dim,
        context_pre_only=True,
    )

and this time it works. But I'm not very sure it is correct to do so.

ChenhLiwnl commented 1 week ago

SD3's transformer blocks are:

    self.transformer_blocks = nn.ModuleList(
        [
            JointTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=self.config.num_attention_heads,
                attention_head_dim=self.inner_dim,
                context_pre_only=i == num_layers - 1,
            )
            for i in range(self.config.num_layers)
        ]
    )

while SD3ControlNet's transformer blocks are:

    self.transformer_blocks = nn.ModuleList(
        [
            JointTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=self.config.num_attention_heads,
                attention_head_dim=self.inner_dim,
                context_pre_only=False,
            )
            for i in range(self.config.num_layers)
        ]
    )

Maybe this is the problem?
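
A quick way to confirm where the two stacks diverge (this assumes JointTransformerBlock keeps context_pre_only as an attribute, which the definitions above suggest):

    for i, (t_block, c_block) in enumerate(
        zip(transformer.transformer_blocks, controlnet.transformer_blocks)
    ):
        if t_block.context_pre_only != c_block.context_pre_only:
            # with the definitions above, only the last index should print
            print(i, t_block.context_pre_only, c_block.context_pre_only)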

haofanwang commented 1 week ago

@wangqixun Can you explain a bit here? The last block is different.

ChenhLiwnl commented 1 week ago

In fact, the example ControlNet (InstantX/SD3-Controlnet-Canny) can be loaded correctly, so it's a little bit confusing...

haofanwang commented 1 week ago

No worries, I'm testing. Will update soon.

haofanwang commented 1 week ago

Things are a bit different for a transformer-based ControlNet, because we don't fix the number of layers in the ControlNet, while a UNet-based ControlNet always uses down_blocks and mid_block.

(1) If we set context_pre_only=i == num_layers - 1 in SD3ControlNetModel, we have to set num_layers to the same value as the SD3 base model; otherwise there will be a size-mismatch error. But that is not recommended, because the ControlNet then becomes very heavy and is more like a ReferenceNet. As you can see from our released checkpoints, num_layers for the ControlNet is 6 or 12, i.e. at most a half copy of the base network.

(2) So our current solution is to set num_layers in from_transformer to 12 by default instead of None. Then you can freely load weights from the transformer, because we only use intermediate layers, whose context_pre_only is False. The only limitation is that we cannot set num_layers=24, as the last block is different.

In your case, you can set it manually: controlnet = SD3ControlNetModel.from_transformer(transformer, num_layers=6).
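
For completeness, a runnable version of that call, reusing the checkpoint from the reproduction above (num_layers=6 matches one of the released ControlNet sizes):

    from diffusers.models.controlnet_sd3 import SD3ControlNetModel
    from diffusers.models.transformers import SD3Transformer2DModel

    transformer = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
    )
    # with num_layers < 24, only intermediate blocks are copied,
    # and all of those have context_pre_only=False
    controlnet = SD3ControlNetModel.from_transformer(transformer, num_layers=6)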

@ChenhLiwnl

ChenhLiwnl commented 4 days ago

Sorry, but another question: I noticed that on the main branch the transformer blocks' attention_head_dim is set to self.config.attention_head_dim, while in the released v0.29.2 it is self.inner_dim. Currently, self.config.attention_head_dim of the released model seems to be 64, while self.inner_dim is 1536. They are not the same value, so which one is right?
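
For what it's worth, the two quantities are related rather than interchangeable; assuming the SD3-medium config (num_attention_heads=24 is an assumption on my part), the numbers above line up:

    num_attention_heads = 24  # assumed from the SD3-medium config
    attention_head_dim = 64   # per-head width reported by the released model
    inner_dim = num_attention_heads * attention_head_dim
    assert inner_dim == 1536  # the value seen in v0.29.2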

yiyixuxu commented 2 hours ago

@ChenhLiwnl it is a bug we fixed - main is correct https://github.com/huggingface/diffusers/pull/8608

yiyixuxu commented 2 hours ago

closing this now since the issue is resolved! :)