bshall / acoustic-model

Acoustic models for: A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion
https://bshall.github.io/soft-vc/
MIT License
100 stars 24 forks source link

Finetuned model while loading RuntimeError: Error(s) in loading state_dict for AcousticModel #10

Closed MuruganR96 closed 1 year ago

MuruganR96 commented 1 year ago

@bshall Thank you for this great work.

I did fine-tune the pre-trained acoustic LJSpeech model with my custom dataset (~ 1 hour).

python train.py --resume checkpoints/hubert-soft-0321fd7e.pt data/ finetuned_checkpoints/

I have newly fine-tuned the best model (model-best.pt) with 20000 steps. I modified the code (https://github.com/bshall/acoustic-model/blob/main/acoustic/model.py#L119), changing the loading from torch.hub.load_state_dict_from_url to my local checkpoint path, but I got the error below. I have shared the error log for your reference.

Can you please help me resolve this issue?

Thanks

Traceback (most recent call last):
  File "/root/Experiments/soft-vc/inference.py", line 12, in <module>
    acoustic = hubert_soft().cuda()
  File "/root/Experiments/soft-vc/acoustic/acoustic/model.py", line 165, in hubert_soft
    return _acoustic(
  File "/root/Experiments/soft-vc/acoustic/acoustic/model.py", line 133, in _acoustic
    acoustic.load_state_dict(checkpoint["acoustic-model"])
  File "/root/anaconda3/envs/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AcousticModel:
        Missing key(s) in state_dict: "encoder.prenet.net.0.weight", "encoder.prenet.net.0.bias", "encoder.prenet.net.3.weight", "encoder.prenet.net.3.bias", "encoder.convs.0.weight", "encoder.convs.0.bias", "encoder.convs.3.weight", "encoder.convs.3.bias", "encoder.convs.4.weight", "encoder.convs.4.bias", "encoder.convs.7.weight", "encoder.convs.7.bias", "decoder.prenet.net.0.weight", "decoder.prenet.net.0.bias", "decoder.prenet.net.3.weight", "decoder.prenet.net.3.bias", "decoder.lstm1.weight_ih_l0", "decoder.lstm1.weight_hh_l0", "decoder.lstm1.bias_ih_l0", "decoder.lstm1.bias_hh_l0", "decoder.lstm2.weight_ih_l0", "decoder.lstm2.weight_hh_l0", "decoder.lstm2.bias_ih_l0", "decoder.lstm2.bias_hh_l0", "decoder.lstm3.weight_ih_l0", "decoder.lstm3.weight_hh_l0", "decoder.lstm3.bias_ih_l0", "decoder.lstm3.bias_hh_l0", "decoder.proj.weight". 
        Unexpected key(s) in state_dict: "module.encoder.prenet.net.0.weight", "module.encoder.prenet.net.0.bias", "module.encoder.prenet.net.3.weight", "module.encoder.prenet.net.3.bias", "module.encoder.convs.0.weight", "module.encoder.convs.0.bias", "module.encoder.convs.3.weight", "module.encoder.convs.3.bias", "module.encoder.convs.4.weight", "module.encoder.convs.4.bias", "module.encoder.convs.7.weight", "module.encoder.convs.7.bias", "module.decoder.prenet.net.0.weight", "module.decoder.prenet.net.0.bias", "module.decoder.prenet.net.3.weight", "module.decoder.prenet.net.3.bias", "module.decoder.lstm1.weight_ih_l0", "module.decoder.lstm1.weight_hh_l0", "module.decoder.lstm1.bias_ih_l0", "module.decoder.lstm1.bias_hh_l0", "module.decoder.lstm2.weight_ih_l0", "module.decoder.lstm2.weight_hh_l0", "module.decoder.lstm2.bias_ih_l0", "module.decoder.lstm2.bias_hh_l0", "module.decoder.lstm3.weight_ih_l0", "module.decoder.lstm3.weight_hh_l0", "module.decoder.lstm3.bias_ih_l0", "module.decoder.lstm3.bias_hh_l0", "module.decoder.proj.weight". 
def _acoustic(
    name: str,
    discrete: bool,
    upsample: bool,
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    """Build an AcousticModel and optionally load fine-tuned weights.

    Args:
        name: key into URLS for the hosted checkpoint (unused here because
            loading was redirected to a local fine-tuned checkpoint).
        discrete: whether the model consumes discrete speech units.
        upsample: whether the encoder upsamples its input.
        pretrained: if True, load weights from the checkpoint and put the
            model in eval mode.
        progress: download progress flag (unused with a local checkpoint).

    Returns:
        The AcousticModel, with fine-tuned weights loaded when requested.
    """
    acoustic = AcousticModel(discrete, upsample)
    if pretrained:
        # checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)

        load_path = "/root/Experiments/soft-vc/acoustic/finetuned_checkpoints/model-best.pt"
        # map_location keeps loading working on CPU-only machines even if
        # the checkpoint was saved from a GPU run.
        checkpoint = torch.load(load_path, map_location="cpu")
        # The training checkpoint was saved from a DataParallel/DDP-wrapped
        # model, so every state_dict key carries a "module." prefix. Strip it
        # before loading, otherwise load_state_dict raises
        # "Missing key(s) ... Unexpected key(s): module.encoder..." errors.
        consume_prefix_in_state_dict_if_present(checkpoint["acoustic-model"], "module.")
        acoustic.load_state_dict(checkpoint["acoustic-model"])
        acoustic.eval()
    return acoustic
MuruganR96 commented 1 year ago

I fixed the issue — it was my mistake. This change worked:

consume_prefix_in_state_dict_if_present(checkpoint["acoustic-model"], "module.")

def _acoustic(
    name: str,
    discrete: bool,
    upsample: bool,
    pretrained: bool = True,
    progress: bool = True,
) -> AcousticModel:
    """Instantiate an AcousticModel, loading fine-tuned weights when requested."""
    model = AcousticModel(discrete, upsample)
    if not pretrained:
        return model
    # checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
    checkpoint = torch.load("/root/Experiments/soft-vc/acoustic-model-0.1/finetuned_checkpoints/model-best.pt")
    state = checkpoint["acoustic-model"]
    # Drop the "module." prefix (left over from the parallel-wrapped training
    # model) in place before loading the weights.
    consume_prefix_in_state_dict_if_present(state, "module.")
    model.load_state_dict(state)
    model.eval()
    return model