johannakarras / DreamPose

Official implementation of "DreamPose: Fashion Image-to-Video Synthesis via Stable Diffusion"
MIT License

dict key error in demo #49

Open DmitriyPin opened 1 year ago

DmitriyPin commented 1 year ago

Hello. When I run the demo, I get the following dictionary key error. I am using PyTorch 2.0.1 and CUDA 11.7. Is there a specific version of PyTorch that I should use? Any ideas on how I can resolve this? Thank you.

Traceback (most recent call last):
  File "test.py", line 87, in <module>
    pipe.vae.load_state_dict(new_state_dict)
  File "C:\Users\best4\AppData\Roaming\Python\Python37\site-packages\torch\nn\modules\module.py", line 1672, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for AutoencoderKL:
        Missing key(s) in state_dict: "encoder.mid_block.attentions.0.to_q.weight", "encoder.mid_block.attentions.0.to_q.bias", "encoder.mid_block.attentions.0.to_k.weight", "encoder.mid_block.attentions.0.to_k.bias", "encoder.mid_block.attentions.0.to_v.weight", "encoder.mid_block.attentions.0.to_v.bias", "encoder.mid_block.attentions.0.to_out.0.weight", "encoder.mid_block.attentions.0.to_out.0.bias", "decoder.mid_block.attentions.0.to_q.weight", "decoder.mid_block.attentions.0.to_q.bias", "decoder.mid_block.attentions.0.to_k.weight", "decoder.mid_block.attentions.0.to_k.bias", "decoder.mid_block.attentions.0.to_v.weight", "decoder.mid_block.attentions.0.to_v.bias", "decoder.mid_block.attentions.0.to_out.0.weight", "decoder.mid_block.attentions.0.to_out.0.bias".
        Unexpected key(s) in state_dict: "encoder.mid_block.attentions.0.query.weight", "encoder.mid_block.attentions.0.query.bias", "encoder.mid_block.attentions.0.key.weight", "encoder.mid_block.attentions.0.key.bias", "encoder.mid_block.attentions.0.value.weight", "encoder.mid_block.attentions.0.value.bias", "encoder.mid_block.attentions.0.proj_attn.weight", "encoder.mid_block.attentions.0.proj_attn.bias", "decoder.mid_block.attentions.0.query.weight", "decoder.mid_block.attentions.0.query.bias", "decoder.mid_block.attentions.0.key.weight", "decoder.mid_block.attentions.0.key.bias", "decoder.mid_block.attentions.0.value.weight", "decoder.mid_block.attentions.0.value.bias", "decoder.mid_block.attentions.0.proj_attn.weight", "decoder.mid_block.attentions.0.proj_attn.bias".

VoHoangAnh commented 1 year ago

Have you fixed it yet? I have the same problem.

DmitriyPin commented 1 year ago

No, I do not have a fix for this.

hughkhu commented 1 year ago

I have the same problem.

VoHoangAnh commented 1 year ago

I replaced the key names and it worked.

LaiaTarres commented 1 year ago

Same! I replaced the names like this, and the error no longer appears:

        name = name.replace('query.', 'to_q.')
        name = name.replace('key.', 'to_k.')
        name = name.replace('value.', 'to_v.')
        name = name.replace('proj_attn.', 'to_out.')
        name = name.replace('.mid_block.attentions.0.to_out.', '.mid_block.attentions.0.to_out.0.')

TrainColab commented 11 months ago

@LaiaTarres Hi Laia, can you please help me find where I should replace these keywords (file name and line)? Thank you so much.

xavier111222 commented 9 months ago

@LaiaTarres Can you send the code that needs to be changed or replaced? I would love to be able to test it because it really amazes me. zonqio@outlook.com Thank you.

SaharaSheik commented 7 months ago

I fixed this by modifying the following lines (this is a common issue; all you need to do is change the state dict key names so that they match the names the model expects):

    for k, v in vae_state_dict.items():
        # Strip the 'module.' prefix, then map the old attention names to the new ones.
        name = k.replace('module.', '')
        name = name.replace('query', 'to_q')
        name = name.replace('key', 'to_k')
        name = name.replace('value', 'to_v')
        name = name.replace('proj_attn', 'to_out.0')
        new_state_dict[name] = v
    pipe.vae.load_state_dict(new_state_dict)
    pipe.vae = pipe.vae.cuda()
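
For anyone asking where this goes: the traceback points at `test.py`, line 87, the `pipe.vae.load_state_dict(new_state_dict)` call, so the renaming presumably belongs in the loop that builds `new_state_dict` just before it. Below is a minimal, self-contained sketch of the same renaming that can be tried without loading the model. The helper names `rename_legacy_attention_keys` and `key_mismatch` are hypothetical (they are not part of DreamPose or diffusers), and anchoring the match on `.attentions.<n>.` is an assumption based on the key names listed in the traceback.

```python
import re

# Old attention parameter names -> new ones, as reported in this thread.
_LEGACY_TO_NEW = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",
}

# Assumption: legacy names only appear right after ".attentions.<n>.",
# which is where the traceback shows them. Other keys are left untouched.
_PATTERN = re.compile(r"(\.attentions\.\d+\.)(query|key|value|proj_attn)\.")


def rename_legacy_attention_keys(state_dict):
    """Return a new dict with a leading 'module.' stripped and legacy names mapped."""
    renamed = {}
    for key, tensor in state_dict.items():
        # Unlike str.replace, this only strips 'module.' at the start of the key.
        name = key[len("module."):] if key.startswith("module.") else key
        name = _PATTERN.sub(
            lambda m: m.group(1) + _LEGACY_TO_NEW[m.group(2)] + ".", name
        )
        renamed[name] = tensor
    return renamed


def key_mismatch(expected_keys, state_dict):
    """Return (missing, unexpected) key lists, like the load_state_dict error."""
    expected, actual = set(expected_keys), set(state_dict)
    return sorted(expected - actual), sorted(actual - expected)


# Tiny stand-in for the checkpoint; real values would be tensors.
old = {
    "module.encoder.mid_block.attentions.0.query.weight": 0,
    "module.decoder.mid_block.attentions.0.proj_attn.bias": 1,
    "module.encoder.conv_in.weight": 2,
}
new = rename_legacy_attention_keys(old)
print(list(new))
```

In `test.py` this would likely be used as `pipe.vae.load_state_dict(rename_legacy_attention_keys(vae_state_dict))`, reusing the variable names from the snippet above. Calling `key_mismatch(pipe.vae.state_dict().keys(), new)` beforehand is one way to check that nothing is still missing or unexpected.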