johannakarras / DreamPose

Official implementation of "DreamPose: Fashion Image-to-Video Synthesis via Stable Diffusion"
MIT License

`unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape)`: I assume 3 should be 4 #37

Closed: zhangtao22 closed this issue 1 year ago

zhangtao22 commented 1 year ago

```python
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel

def get_unet(pretrained_model_name_or_path, revision, resolution=256, n_poses=5):
    # Load pretrained UNet layers
    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        subfolder="unet",
        revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
    )

    # Modify the input layer to accept additional pose input channels
    weights = unet.conv_in.weight.clone()
    unet.conv_in = nn.Conv2d(4 + 2*n_poses, weights.shape[0], kernel_size=3, padding=(1, 1))  # input noise + n poses
    with torch.no_grad():
        unet.conv_in.weight[:, :4] = weights  # original weights
        unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape)  # new weights initialized to zero
```

The original `conv_in` weight has shape `(320, 4, 3, 3)`, so `unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape)` also zeroes input channel 3, overwriting the last of the four pretrained weight channels that were copied in on the previous line.
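For reference, a minimal sketch of the corrected initialization, assuming the intent is to keep all four pretrained latent channels and zero only the newly added pose channels (the `[:, 4:]` indexing is the proposed fix, not code from the repo):

```python
with torch.no_grad():
    unet.conv_in.weight[:, :4] = weights  # keep all four pretrained latent channels
    unet.conv_in.weight[:, 4:] = torch.zeros(unet.conv_in.weight[:, 4:].shape)  # zero only the new pose channels
```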

dlutzzw commented 1 year ago

Yes, I also noticed this point. Maybe `[:, 4:]` would be more reasonable.

zhangtao22 commented 1 year ago

@dlutzzw The author was a bit careless here, but then how were the demo results actually obtained?

johannakarras commented 1 year ago

Thanks for the feedback, this bug is now fixed.
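For anyone verifying the fix locally, a quick sanity check could confirm that the pretrained channels survive and the new pose channels start at zero (a hypothetical snippet, not from the repo, assuming `get_unet` now uses the `[:, 4:]` indexing):

```python
import torch

# After get_unet() with the [:, 4:] fix, the four pretrained latent
# channels should be nonzero and the appended pose channels all zero.
unet = get_unet("CompVis/stable-diffusion-v1-4", revision=None)
w = unet.conv_in.weight
assert w[:, 3].abs().sum() > 0, "pretrained channel 3 must not be zeroed"
assert torch.all(w[:, 4:] == 0), "new pose channels should start at zero"
```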