tanshuai0219 / EDTalk

[ECCV 2024 Oral] EDTalk - Official PyTorch Implementation
Apache License 2.0
344 stars, 31 forks

How to continue training from train_E_G.py to train_Mouth_Pose.py? #35

Closed: hbkang-slx closed this issue 1 month ago

hbkang-slx commented 1 month ago

The README explains that train_E_G.py is used to pretrain the encoder and generator, but doesn't explain how the pretrained model will be used in train_Mouth_Pose.py or train_audio2mouth.py.

Is pointing the --resume_ckpt argument at the pretrained checkpoints enough, even though the model architecture changes between the two scripts? (e.g. train/networks/styledecoder.py vs train/networks_Lip_NonLip/styledecoder.py)

tanshuai0219 commented 1 month ago

Thanks for your interest. To answer your question: yes. Most of the parameters in train/networks/styledecoder.py and train/networks_Lip_NonLip/styledecoder.py are the same, so we only need to load the weights for those shared parameters. The parameters that differ belong to the new part of the code, and that is the part that needs to be trained.
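The answer above amounts to partial loading: copy every checkpoint entry whose name and shape also exist in the new network, and leave everything else at its fresh initialization so it gets trained. A minimal sketch of that idea, assuming the generator weights from the train_E_G.py checkpoint have already been extracted into a flat `{parameter name: tensor}` dict (the exact checkpoint layout is an assumption here); `filter_matching_weights` is a hypothetical helper, not part of EDTalk:

```python
from collections import OrderedDict
from typing import Any, List, Mapping, Tuple


def filter_matching_weights(
    checkpoint: Mapping[str, Any],
    model_state: Mapping[str, Any],
) -> Tuple["OrderedDict[str, Any]", List[str]]:
    """Hypothetical helper (not part of EDTalk).

    Keeps checkpoint entries whose name exists in the new model and whose
    shape matches, and returns them together with the names that were skipped.
    Assumes `checkpoint` is a flat {parameter name: tensor} mapping.
    """
    kept: "OrderedDict[str, Any]" = OrderedDict()
    skipped: List[str] = []
    for key, value in checkpoint.items():
        target = model_state.get(key)
        # Compare shapes via getattr so any tensor-like object works.
        if target is None or getattr(value, "shape", None) != getattr(target, "shape", None):
            skipped.append(key)  # removed, renamed, or resized in the new network
            continue
        kept[key] = value
    return kept, skipped
```

In PyTorch the filtered dict would typically be applied with `model.load_state_dict(kept, strict=False)`, where `model` is the network built from `train/networks_Lip_NonLip`. With `strict=False` missing entries do not raise; the returned `missing_keys` list shows which new parameters will start from scratch. Size mismatches still raise even with `strict=False`, which is why the sketch also filters on shape.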

hbkang-slx commented 1 month ago

Do I need to manually change the weights stored in the checkpoints, or change the trainer_*.py code to handle missing module weights?

tanshuai0219 commented 1 month ago

Try removing the `#` comment markers in https://github.com/tanshuai0219/EDTalk/blob/a375dbdec5128372cb62988c68baebdd6ed551a9/train/trainer_Mouth_Pose_decouple.py#L325:

```python
from collections import OrderedDict

new_state_dict = OrderedDict()

# for key, value in checkpoint.items():
#     if 'enc.fc.' in key:
#         if 'enc.fc.4' in key:
#             continue
#         name = key.split('enc.fc.')[1]
#         new_state_dict[name] = value

self.gen.load_state_dict(checkpoint)

# new_state_dict = OrderedDict()
# for key, value in checkpoint.items():
#     if 'enc.net_app.' in key:
#         name = key.split('enc.')[1]
#         new_state_dict[name] = value
# self.gen.enc.load_state_dict(new_state_dict)

# new_state_dict = OrderedDict()
# for key, value in checkpoint.items():
#     if 'dec.' in key:
#         if 'dec.direc' in key:
#             continue
#         name = key.split('dec.')[1]
#         new_state_dict[name] = value
# self.gen.dec.load_state_dict(new_state_dict)
```
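The commented-out loops in that snippet all follow one pattern: keep the checkpoint keys that contain some prefix, skip a few sub-modules (`enc.fc.4`, `dec.direc`, presumably because they differ in the new architecture), strip the prefix, and load the result into a sub-module. A small helper capturing that pattern might look like the sketch below; `sub_state_dict` and its parameter names are hypothetical, not part of the EDTalk code, and like the original loops it matches with a substring test (`in`) rather than `startswith`.

```python
from collections import OrderedDict
from typing import Any, Iterable, Mapping, Optional


def sub_state_dict(
    checkpoint: Mapping[str, Any],
    match: str,
    strip: Optional[str] = None,
    skip: Iterable[str] = (),
) -> "OrderedDict[str, Any]":
    """Hypothetical helper (not part of EDTalk).

    Selects entries whose key contains `match`, drops keys containing any
    `skip` substring, and removes everything up to and including the first
    occurrence of `strip` (defaults to `match`; must be a substring of it).
    """
    strip = match if strip is None else strip
    skip = tuple(skip)
    out: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in checkpoint.items():
        if match not in key:
            continue  # belongs to another sub-module
        if any(s in key for s in skip):
            continue  # deliberately excluded, e.g. a layer that changed
        out[key.split(strip, 1)[1]] = value  # keep the name relative to the sub-module
    return out
```

Inside the trainer, the `dec.` loop would then reduce to `self.gen.dec.load_state_dict(sub_state_dict(checkpoint, "dec.", skip=("dec.direc",)))`, and the `enc.net_app.` loop to `sub_state_dict(checkpoint, "enc.net_app.", strip="enc.")`. Passing `strict=False` to `load_state_dict` would be needed if the new sub-module has parameters the checkpoint does not provide.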

hbkang-slx commented 1 month ago

Thank you for the explanation. Closed.