zzk88862 opened 1 year ago
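Hi, I fixed this error by reverting this commit, i.e. changing this line:
unet.conv_in.weight[:, 4:] = torch.zeros(unet.conv_in.weight[:, 3:].shape)
back to what was previously there:
unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape)
In the first version the slice being assigned ([:, 4:]) has one input channel fewer than the zeros tensor built from [:, 3:], so the shapes do not match.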
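A minimal sketch of the same idea, assuming a plain torch.nn.Conv2d stands in for unet.conv_in (the helper name zero_extra_input_channels and the layer sizes are hypothetical, not DreamPose's code). Building the zeros from the exact slice being assigned keeps the shapes in sync, and wrapping the in-place write in torch.no_grad() avoids the autograd error PyTorch raises when a leaf parameter that requires grad is modified in place.

```python
import torch
import torch.nn as nn


def zero_extra_input_channels(conv: nn.Conv2d, keep: int) -> None:
    """Zero the weights of every input channel from index `keep` onward.

    Hypothetical helper for illustration; `conv` plays the role of unet.conv_in.
    """
    with torch.no_grad():
        # The right-hand side is derived from the same slice as the left-hand
        # side, so the two shapes always agree.
        conv.weight[:, keep:] = torch.zeros_like(conv.weight[:, keep:])


# Hypothetical layer: 9 input channels, 8 output channels.
conv = nn.Conv2d(9, 8, kernel_size=3, padding=1)
zero_extra_input_channels(conv, keep=3)
```

By contrast, mixing the two slices (assigning zeros shaped like [:, 3:] into [:, 4:]) raises a RuntimeError, because a tensor with 6 input channels cannot be copied into a view with 5.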
@LaiaTarres thanks for your method. I changed this line to unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape),
but now another error happens.
How should I fix this error?
The workaround is in this other issue.
@zzk88862 I added it just after this line https://github.com/johannakarras/DreamPose/blob/main/test.py#L54
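I fixed this by renaming the keys of the VAE state dict. This is a common issue: all you need to do is change the state dict key names so that they match what the model expects (older checkpoints use query/key/value/proj_attn for the attention layers, while newer diffusers versions expect to_q/to_k/to_v/to_out.0):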
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in vae_state_dict.items():
    # Strip the 'module.' prefix left by (Distributed)DataParallel
    name = k.replace('module.', '')
    # Map the old attention parameter names to the ones newer diffusers expects
    name = name.replace('query', 'to_q')
    name = name.replace('key', 'to_k')
    name = name.replace('value', 'to_v')
    name = name.replace('proj_attn', 'to_out.0')
    new_state_dict[name] = v
pipe.vae.load_state_dict(new_state_dict)
pipe.vae = pipe.vae.cuda()
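The same renaming can be packaged as a small function. This is a sketch, not the fix posted in the thread: the name rename_vae_keys and the example keys are hypothetical, and it assumes the checkpoint only differs by the 'module.' prefix and the four attention names listed above. Unlike chained str.replace calls, it renames whole dot-separated components only, so a name that merely contains "key" inside a longer word is left alone.

```python
from typing import Any, Dict, Mapping

# Old attention parameter names -> names used by newer diffusers versions
# (the same mapping as the renaming loop in the thread).
_RENAMES = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",
}


def rename_vae_keys(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of `state_dict` with keys renamed (hypothetical helper)."""
    renamed: Dict[str, Any] = {}
    for key, value in state_dict.items():
        # Strip a single leading 'module.' prefix added by (Distributed)DataParallel.
        if key.startswith("module."):
            key = key[len("module."):]
        # Rename whole components only, e.g. 'query' but not 'queryfoo'.
        parts = [_RENAMES.get(part, part) for part in key.split(".")]
        renamed[".".join(parts)] = value
    return renamed


# Illustrative keys only; real checkpoints hold tensors as values.
old = {
    "module.encoder.mid_block.attentions.0.query.weight": 1,
    "module.encoder.mid_block.attentions.0.proj_attn.bias": 2,
    "decoder.conv_in.weight": 3,
}
new = rename_vae_keys(old)
```

With the thread's variables this would be used as pipe.vae.load_state_dict(rename_vae_keys(vae_state_dict)). Since load_state_dict is strict by default, any key this mapping does not cover will still be reported as missing or unexpected.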