Weizhi-Zhong / IP_LAP

CVPR2023 talking face implementation for Identity-Preserving Talking Face Generation With Landmark and Appearance Priors
Apache License 2.0
637 stars, 72 forks

load trained_checkpoint error #9

Closed: yaleimeng closed this issue 1 year ago

yaleimeng commented 1 year ago

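After training the landmark and renderer models, running inference with inference_single.py fails while loading a checkpoint: the state dict is missing some keys. The output is as follows: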

landmark_generator_model loaded from : checkpoints/landmark_generation/Pro_landmarkT5_d512_fe1024_lay4_head4/landmarkT5_d512_fe1024_lay4_head4_epoch_2020_checkpoint_step000012120.pth
renderer loaded from : checkpoints/renderer/Pro_renderer_T1_ref_N3/renderer_T1_ref_N3_epoch_7000_checkpoint_step000042000.pth
Load checkpoint from: checkpoints/landmark_generation/Pro_landmarkT5_d512_fe1024_lay4_head4/landmarkT5_d512_fe1024_lay4_head4_epoch_2020_checkpoint_step000012120.pth
--local/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
--local/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=VGG19_Weights.IMAGENET1K_V1. You can also use weights=VGG19_Weights.DEFAULT to get the most up-to-date weights.
  warnings.warn(msg)

Perceptual loss: Mode: vgg19
Load checkpoint from: checkpoints/renderer/Pro_renderer_T1_ref_N3/renderer_T1_ref_N3_epoch_7000_checkpoint_step000042000.pth
Traceback (most recent call last):
  File "IP_LAP/inference_single.py", line 194, in <module>
    renderer = load_model(model=Renderer(), path=renderer_checkpoint_path)
  File "IP_LAP/inference_single.py", line 173, in load_model
    model.load_state_dict(new_s)
  File "local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Renderer:
Missing key(s) in state_dict: "flow_module.conv1.weight", "flow_module.conv1.bias", "flow_module.conv1_bn.weight", "flow_module.conv1_bn.bias", "flow_module.conv1_bn.running_mean", "flow_module.conv1_bn.running_var", "flow_module.conv2.weight", "flow_module.conv2.bias", "flow_module.conv2_bn.weight", "flow_module.conv2_bn.bias", "flow_module.conv2_bn.running_mean", "flow_module.conv2_bn.running_var", "flow_module.spade_layer_1.conv_1.weight", "flow_module.spade_layer_1.conv_1.bias", "flow_module.spade_layer_1.conv_2.weight", "flow_module.spade_layer_1.conv_2.bias", "flow_module.spade_layer_1.spade_layer_1.conv1.weight", "flow_module.spade_layer_1.spade_layer_1.conv1.bias", "flow_module.spade_layer_1.spade_layer_1.gamma.weight", "flow_module.spade_layer_1.spade_layer_1.gamma.bias", "flow_module.spade_layer_1.spade_layer_1.beta.weight", "flow_module.spade_layer_1.spade_layer_1.beta.bias", "flow_module.spade_layer_1.spade_layer_2.conv1.weight", "flow_module.spade_layer_1.spade_layer_2.conv1.bias", "flow_module.spade_layer_1.spade_layer_2.gamma.weight", "flow_module.spade_layer_1.spade_layer_2.gamma.bias", "flow_module.spade_layer_1.spade_layer_2.beta.weight", "flow_module.spade_layer_1.spade_layer_2.beta.bias", "flow_module.spade_layer_2.conv_1.weight", "flow_module.spade_layer_2.conv_1.bias", "flow_module.spade_layer_2.conv_2.weight", 
"flow_module.spade_layer_2.conv_2.bias", "flow_module.spade_layer_2.spade_layer_1.conv1.weight", "flow_module.spade_layer_2.spade_layer_1.conv1.bias", "flow_module.spade_layer_2.spade_layer_1.gamma.weight", "flow_module.spade_layer_2.spade_layer_1.gamma.bias", "flow_module.spade_layer_2.spade_layer_1.beta.weight", "flow_module.spade_layer_2.spade_layer_1.beta.bias", "flow_module.spade_layer_2.spade_layer_2.conv1.weight", "flow_module.spade_layer_2.spade_layer_2.conv1.bias", "flow_module.spade_layer_2.spade_layer_2.gamma.weight", "flow_module.spade_layer_2.spade_layer_2.gamma.bias", "flow_module.spade_layer_2.spade_layer_2.beta.weight", "flow_module.spade_layer_2.spade_layer_2.beta.bias", "flow_module.spade_layer_4.conv_1.weight", "flow_module.spade_layer_4.conv_1.bias", "flow_module.spade_layer_4.conv_2.weight", "flow_module.spade_layer_4.conv_2.bias", "flow_module.spade_layer_4.spade_layer_1.conv1.weight", "flow_module.spade_layer_4.spade_layer_1.conv1.bias", "flow_module.spade_layer_4.spade_layer_1.gamma.weight", "flow_module.spade_layer_4.spade_layer_1.gamma.bias", "flow_module.spade_layer_4.spade_layer_1.beta.weight", "flow_module.spade_layer_4.spade_layer_1.beta.bias", "flow_module.spade_layer_4.spade_layer_2.conv1.weight", "flow_module.spade_layer_4.spade_layer_2.conv1.bias", "flow_module.spade_layer_4.spade_layer_2.gamma.weight", "flow_module.spade_layer_4.spade_layer_2.gamma.bias", "flow_module.spade_layer_4.spade_layer_2.beta.weight", "flow_module.spade_layer_4.spade_layer_2.beta.bias", "flow_module.conv_4.weight", "flow_module.conv_4.bias", "flow_module.conv_5.0.weight", "flow_module.conv_5.0.bias", "flow_module.conv_5.2.weight", "flow_module.conv_5.2.bias".

Weizhi-Zhong commented 1 year ago

@yaleimeng Hi~, thanks for your interest, and sorry for the bug. The problem may be related to the line "new_s[k.replace('module.', '', 1)] = v" in inference_single.py and the line "self.flow_module = DenseFlowNetwork()" in video_renderer.py.
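A likely reason only flow_module.* keys are reported missing: Python's str.replace(old, new, count) replaces the first occurrence of old anywhere in the string, not just at the start. A minimal sketch, using key names from the missing-key list in the first traceback (the third key is a hypothetical name for the same layer saved from an nn.DataParallel-wrapped model):

```python
# Keys from the Renderer error log, plus one hypothetical DataParallel-style key.
keys = [
    "flow_module.conv1.weight",         # single-GPU checkpoint key
    "flow_module.conv_5.2.bias",        # single-GPU checkpoint key
    "module.flow_module.conv1.weight",  # same layer saved via nn.DataParallel
]

# Same call as the original loader: remove the first 'module.' found anywhere.
renamed = [k.replace('module.', '', 1) for k in keys]

for old, new in zip(keys, renamed):
    print(old, "->", new)
# flow_module.conv1.weight -> flow_conv1.weight
# flow_module.conv_5.2.bias -> flow_conv_5.2.bias
# module.flow_module.conv1.weight -> flow_module.conv1.weight
```

When the checkpoint keys carry no leading module. prefix, the only match is the one inside flow_module, so those keys are renamed and no longer match the Renderer's parameters.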

I guess you trained the renderer on a single GPU. Try replacing that line in inference_single.py with the following code:

for k, v in s.items():
    if k[:6] == 'module':
        new_s[k.replace('module.', '', 1)] = v

Could you let me know whether it works after you try it? Thank you very much.

yaleimeng commented 1 year ago

Yes, I did train on a single GPU. After the change, the renderer model loads correctly, but now loading the landmark model fails, with many keys missing from the state dict:

Load checkpoint from: checkpoints/landmark_generation/Pro_landmarkT5_d512_fe1024_lay4_head4/landmarkT5_d512_fe1024_lay4_head4_epoch_2020_checkpoint_step000012120.pth
Traceback (most recent call last):
  File "/IP_LAP/inference_single.py", line 193, in <module>
    landmark_generator_model = load_model(
  File "/IP_LAP/inference_single.py", line 175, in load_model
    model.load_state_dict(new_s)
  File "/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Landmark_generator:
Missing key(s) in state_dict: "mel_encoder.0.conv_block.0.weight", "mel_encoder.0.conv_block.0.bias", "mel_encoder.0.conv_block.1.weight", "mel_encoder.0.conv_block.1.bias", "mel_encoder.0.conv_block.1.running_mean", "mel_encoder.0.conv_block.1.running_var", "mel_encoder.1.conv_block.0.weight", "mel_encoder.1.conv_block.0.bias", "mel_encoder.1.conv_block.1.weight", "mel_encoder.1.conv_block.1.bias", "mel_encoder.1.conv_block.1.running_mean", "mel_encoder.1.conv_block.1.running_var", "mel_encoder.2.conv_block.0.weight", "mel_encoder.2.conv_block.0.bias", "mel_encoder.2.conv_block.1.weight", "mel_encoder.2.conv_block.1.bias", "mel_encoder.2.conv_block.1.running_mean", "mel_encoder.2.conv_block.1.running_var", "mel_encoder.3.conv_block.0.weight", "mel_encoder.3.conv_block.0.bias", "mel_encoder.3.conv_block.1.weight", "mel_encoder.3.conv_block.1.bias", "mel_encoder.3.conv_block.1.running_mean", "mel_encoder.3.conv_block.1.running_var", "mel_encoder.4.conv_block.0.weight", "mel_encoder.4.conv_block.0.bias", ***** (many more, omitted)
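The loop from the first suggested fix only copies keys whose first six characters are 'module' and has no else branch, so any key without that prefix is silently dropped. A minimal sketch, assuming a checkpoint whose keys carry no prefix; the key names come from the landmark error log and the values are placeholders:

```python
# Placeholder values; only the key names matter (taken from the landmark log).
s = {
    "mel_encoder.0.conv_block.0.weight": 0,
    "mel_encoder.0.conv_block.0.bias": 0,
}

# The first suggested loop: unprefixed keys never reach new_s.
new_s = {}
for k, v in s.items():
    if k[:6] == 'module':
        new_s[k.replace('module.', '', 1)] = v

dropped = sorted(set(s) - set(new_s))
print(new_s)    # {}
print(dropped)  # ['mel_encoder.0.conv_block.0.bias', 'mel_encoder.0.conv_block.0.weight']
```

Every dropped key would then be reported by load_state_dict as missing, which is consistent with the Landmark_generator traceback.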

Weizhi-Zhong commented 1 year ago

@yaleimeng Sorry for my negligence. Try replacing load_model in inference_single.py with the following code:

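def load_model(model, path):
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)  # _load and device are defined elsewhere in inference_single.py
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # Strip the 'module.' prefix added when training with nn.DataParallel /
        # DistributedDataParallel, but only at the start of the key, so names like
        # 'flow_module.conv1.weight' stay intact. Unprefixed keys are kept as-is.
        if k.startswith('module.'):
            new_k = k[len('module.'):]
        else:
            new_k = k
        new_s[new_k] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()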

Also, could you let me know whether it works after you try it? Thank you very much.
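The key handling in the corrected load_model can be pulled into a small helper and checked without a GPU or a real checkpoint. A sketch, assuming plain dicts stand in for state dicts; strip_module_prefix is a hypothetical name, not part of IP_LAP, and the example keys mix names from both error logs purely for illustration. PyTorch 1.9 and later also ship torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips a leading prefix in place.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key.

    Only a prefix at the very start of a key is stripped, so names such as
    'flow_module.conv1.weight' are left intact, and keys without the prefix
    are kept unchanged rather than dropped.
    """
    new_s = OrderedDict()
    for k, v in state_dict.items():
        new_k = k[len(prefix):] if k.startswith(prefix) else k
        new_s[new_k] = v
    return new_s


# Keys as saved from a DataParallel-wrapped model (multi-GPU training).
multi_gpu = {"module.flow_module.conv1.weight": 1, "module.mel_encoder.0.conv_block.0.bias": 2}
# Keys as saved from a plain model (single-GPU training).
single_gpu = {"flow_module.conv1.weight": 1, "mel_encoder.0.conv_block.0.bias": 2}

print(list(strip_module_prefix(multi_gpu)))
print(list(strip_module_prefix(single_gpu)))
# Both print: ['flow_module.conv1.weight', 'mel_encoder.0.conv_block.0.bias']
```

Because both checkpoint styles map to the same key names, one loader works regardless of how many GPUs were used for training.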

yaleimeng commented 1 year ago

Thanks, it works.

Weizhi-Zhong commented 1 year ago

Thank you very much~