Modify the input layer to accept 2*n_poses additional input channels (pose)
# Widen the UNet input convolution: keep the 4 original input channels and
# append 2 * n_poses extra channels for the pose conditioning.
weights = unet.conv_in.weight.clone()  # snapshot of the existing kernel before the layer is replaced
# NOTE(review): only the weight is carried over; the replacement layer gets a
# freshly initialized bias — confirm this is intended.
unet.conv_in = nn.Conv2d(4 + 2 * n_poses, weights.shape[0], kernel_size=3, padding=(1, 1))  # input noise + n poses
with torch.no_grad():
    unet.conv_in.weight[:, :4] = weights  # original weights
    # Zero only the newly added channels (index 4 onward). Starting the slice
    # at 3 would also wipe the fourth original channel copied just above.
    unet.conv_in.weight[:, 4:] = torch.zeros_like(unet.conv_in.weight[:, 4:])  # new weights initialized to zero
The original weight of conv_in has shape 320, 4 in its first two dimensions, so unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape) would also set the weights at input-channel index 3 (the fourth original channel) to 0.
def get_unet(pretrained_model_name_or_path, revision, resolution=256, n_poses=5):
Load pretrained UNet layers
Modify the input layer to accept 2*n_poses additional input channels (pose)
The original weight of conv_in has shape 320, 4 in its first two dimensions, so unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape) would also set the weights at input-channel index 3 (the fourth original channel) to 0.