xichenpan / ARLDM

Official Pytorch Implementation of Synthesizing Coherent Story with Auto-Regressive Latent Diffusion Models
https://arxiv.org/abs/2211.10950
MIT License

updating Stable Diffusion to 2.1? #12

Closed KyonP closed 1 year ago

KyonP commented 1 year ago

Thank you for your repository. It greatly helped.

Is there a plan to update the version of Stable Diffusion used in your code to 2.1?

I naively changed the model path to "stabilityai/stable-diffusion-2-1-base", and it failed in "models.diffusers_override.unet_2d_blocks.py" at the line "out_channels // attn_num_head_channels,":

class CrossAttnUpBlock2D(nn.Module):
    def __init__(
            [omitted]

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Transformer2DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,  # <-- the failing line
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )

There seems to be a difference in "attn_num_head_channels" between the two versions.

I am sorry that I cannot provide the error message, because I have already reverted my code.

In v2.1, "attn_num_head_channels" is a list, not an int.

Before I hack into it, I thought it would be a good time to ask whether this has already been tried. I would be grateful for your advice on this.
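To illustrate the mismatch described above, here is a minimal sketch in plain Python. It shows why integer division by a list fails and how a per-block value could be picked instead. The helper name `per_block_head_dims` is hypothetical and is not part of ARLDM or diffusers. The numeric values are assumptions based on the published UNet configs (a single int of 8 for SD 1.x, and `[5, 10, 20, 20]` for SD 2.1) and should be checked against the `config.json` of the checkpoint you load.

```python
def per_block_head_dims(attention_head_dim, num_blocks):
    """Return one value per UNet block (hypothetical helper, not ARLDM code).

    Accepts either a single int (SD 1.x style config) or a sequence with one
    entry per block (SD 2.1 style config).
    """
    if isinstance(attention_head_dim, int):
        return [attention_head_dim] * num_blocks
    dims = list(attention_head_dim)
    if len(dims) != num_blocks:
        raise ValueError(f"expected {num_blocks} entries, got {len(dims)}")
    return dims


# Assumed values; verify them against the checkpoint's UNet config.
block_out_channels = [320, 640, 1280, 1280]
sd1_value = 8                   # assumed SD 1.x setting (an int)
sd21_value = [5, 10, 20, 20]    # assumed SD 2.1 setting (a list)

# Dividing an int by the whole list is what breaks the quoted line.
try:
    block_out_channels[0] // sd21_value
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Picking the entry that belongs to each block avoids the error.
sd1_heads = per_block_head_dims(sd1_value, len(block_out_channels))
sd21_heads = per_block_head_dims(sd21_value, len(block_out_channels))
sd21_channels_per_head = [
    channels // heads for channels, heads in zip(block_out_channels, sd21_heads)
]
```

This only resolves the type error at construction time. It does not address the differences in weight shapes or conditioning dimensions between the two checkpoints.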

xichenpan commented 1 year ago

Hi @KyonP, sorry for the delayed reply; I was busy with ICCV last week. I don't think you should modify the UNet architecture, since this will cause a loading failure (shape mismatch). Instead, you can add a linear layer between the CLIP and BLIP conditions and the UNet of Stable Diffusion 2.1. You may also need to modify some of the diffusion code with reference to the official Stable Diffusion 2.1 implementation.
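A minimal sketch of the linear-layer idea, assuming the CLIP and BLIP condition embeddings are 768-dimensional and the SD 2.1 UNet expects 1024-dimensional cross-attention inputs. Both numbers are assumptions to check against the actual configs. The class name `ConditionProjector` and the tensor shapes are hypothetical and do not come from the ARLDM code.

```python
import torch
import torch.nn as nn


class ConditionProjector(nn.Module):
    """Maps condition embeddings to the UNet's cross-attention width.

    Hypothetical module for illustration; dimensions are assumptions.
    """

    def __init__(self, cond_dim: int = 768, cross_attention_dim: int = 1024):
        super().__init__()
        self.proj = nn.Linear(cond_dim, cross_attention_dim)

    def forward(self, cond: torch.Tensor) -> torch.Tensor:
        # cond: (batch, sequence_length, cond_dim)
        return self.proj(cond)


projector = ConditionProjector()

# Stand-in for the concatenated CLIP and BLIP condition sequence.
cond = torch.randn(2, 77, 768)

# This tensor would be passed to the UNet as its cross-attention input.
encoder_hidden_states = projector(cond)
```

Because the projection sits outside the UNet, the pretrained UNet weights can be loaded unchanged. The new layer starts from random initialization and has to be trained.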

KyonP commented 1 year ago

Thank you for your reply.

I guessed that you had submitted this work to a conference and were in the middle of that process (CVPR rebuttal or ICCV submission) :) I hope you get good news.

I hacked into your code to upgrade it to SD 2.1 and, unfortunately, failed as you expected; there were too many lines to modify.

I will try your advice.

xichenpan commented 1 year ago

@KyonP Great! Thanks for your understanding! If you have further questions about moving to SD 2.1, feel free to leave me a message~