Xflick / EEND_PyTorch

A PyTorch implementation of End-to-End Neural Diarization
MIT License

Inference on pretrained model #13

Closed AlirezaMorsali closed 1 year ago

AlirezaMorsali commented 1 year ago

Thank you for open-sourcing your work. I'm trying to familiarize myself with the repo and tried running inference with the provided pretrained model, but I get an error when loading the model. I would greatly appreciate your help with this issue.

Here are the arguments I pass to infer.py:

                "${workspaceFolder}/conf/large/infer.yaml",
                "${workspaceFolder}/dataset",
                "${workspaceFolder}/pretrained_models/large/model_callhome.th",
                "${workspaceFolder}/outpath",
                "--model_type",
                "Transformer",
                "--gpu",
                "-0",

And here is the error:

  File "/EEND_PyTorch/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TransformerModel:
        Missing key(s) in state_dict: "encoder.weight", "encoder.bias", "encoder_norm.weight", "encoder_norm.bias", "transformer_encoder.layers.0.self_attn.in_proj_weight", "transformer_encoder.layers.0.self_attn.in_proj_bias", "transformer_encoder.layers.0.self_attn.out_proj.weight", "transformer_encoder.layers.0.self_attn.out_proj.bias", "transformer_encoder.layers.0.linear1.weight", "transformer_encoder.layers.0.linear1.bias", "transformer_encoder.layers.0.linear2.weight", "transformer_encoder.layers.0.linear2.bias", "transformer_encoder.layers.0.norm1.weight", "transformer_encoder.layers.0.norm1.bias", "transformer_encoder.layers.0.norm2.weight", "transformer_encoder.layers.0.norm2.bias", "transformer_encoder.layers.1.self_attn.in_proj_weight", "transformer_encoder.layers.1.self_attn.in_proj_bias", "transformer_encoder.layers.1.self_attn.out_proj.weight", "transformer_encoder.layers.1.self_attn.out_proj.bias", "transformer_encoder.layers.1.linear1.weight", "transformer_encoder.layers.1.linear1.bias", "transformer_encoder.layers.1.linear2.weight", "transformer_encoder.layers.1.linear2.bias", "transformer_encoder.layers.1.norm1.weight", "transformer_encoder.layers.1.norm1.bias", "transformer_encoder.layers.1.norm2.weight", "transformer_encoder.layers.1.norm2.bias", "transformer_encoder.layers.2.self_attn.in_proj_weight", "transformer_encoder.layers.2.self_attn.in_proj_bias", "transformer_encoder.layers.2.self_attn.out_proj.weight", "transformer_encoder.layers.2.self_attn.out_proj.bias", "transformer_encoder.layers.2.linear1.weight", "transformer_encoder.layers.2.linear1.bias", "transformer_encoder.layers.2.linear2.weight", "transformer_encoder.layers.2.linear2.bias", "transformer_encoder.layers.2.norm1.weight", "transformer_encoder.layers.2.norm1.bias", "transformer_encoder.layers.2.norm2.weight", "transformer_encoder.layers.2.norm2.bias", "transformer_encoder.layers.3.self_attn.in_proj_weight", "transformer_encoder.layers.3.self_attn.in_proj_bias", "transformer_encoder.layers.3.self_attn.out_proj.weight", "transformer_encoder.layers.3.self_attn.out_proj.bias", "transformer_encoder.layers.3.linear1.weight", "transformer_encoder.layers.3.linear1.bias", "transformer_encoder.layers.3.linear2.weight", "transformer_encoder.layers.3.linear2.bias", "transformer_encoder.layers.3.norm1.weight", "transformer_encoder.layers.3.norm1.bias", "transformer_encoder.layers.3.norm2.weight", "transformer_encoder.layers.3.norm2.bias", "decoder.weight", "decoder.bias". 
        Unexpected key(s) in state_dict: "module.encoder.weight", "module.encoder.bias", "module.encoder_norm.weight", "module.encoder_norm.bias", "module.transformer_encoder.layers.0.self_attn.in_proj_weight", "module.transformer_encoder.layers.0.self_attn.in_proj_bias", "module.transformer_encoder.layers.0.self_attn.out_proj.weight", "module.transformer_encoder.layers.0.self_attn.out_proj.bias", "module.transformer_encoder.layers.0.linear1.weight", "module.transformer_encoder.layers.0.linear1.bias", "module.transformer_encoder.layers.0.linear2.weight", "module.transformer_encoder.layers.0.linear2.bias", "module.transformer_encoder.layers.0.norm1.weight", "module.transformer_encoder.layers.0.norm1.bias", "module.transformer_encoder.layers.0.norm2.weight", "module.transformer_encoder.layers.0.norm2.bias", "module.transformer_encoder.layers.1.self_attn.in_proj_weight", "module.transformer_encoder.layers.1.self_attn.in_proj_bias", "module.transformer_encoder.layers.1.self_attn.out_proj.weight", "module.transformer_encoder.layers.1.self_attn.out_proj.bias", "module.transformer_encoder.layers.1.linear1.weight", "module.transformer_encoder.layers.1.linear1.bias", "module.transformer_encoder.layers.1.linear2.weight", "module.transformer_encoder.layers.1.linear2.bias", "module.transformer_encoder.layers.1.norm1.weight", "module.transformer_encoder.layers.1.norm1.bias", "module.transformer_encoder.layers.1.norm2.weight", "module.transformer_encoder.layers.1.norm2.bias", "module.transformer_encoder.layers.2.self_attn.in_proj_weight", "module.transformer_encoder.layers.2.self_attn.in_proj_bias", "module.transformer_encoder.layers.2.self_attn.out_proj.weight", "module.transformer_encoder.layers.2.self_attn.out_proj.bias", "module.transformer_encoder.layers.2.linear1.weight", "module.transformer_encoder.layers.2.linear1.bias", "module.transformer_encoder.layers.2.linear2.weight", "module.transformer_encoder.layers.2.linear2.bias", "module.transformer_encoder.layers.2.norm1.weight", "module.transformer_encoder.layers.2.norm1.bias", "module.transformer_encoder.layers.2.norm2.weight", "module.transformer_encoder.layers.2.norm2.bias", "module.transformer_encoder.layers.3.self_attn.in_proj_weight", "module.transformer_encoder.layers.3.self_attn.in_proj_bias", "module.transformer_encoder.layers.3.self_attn.out_proj.weight", "module.transformer_encoder.layers.3.self_attn.out_proj.bias", "module.transformer_encoder.layers.3.linear1.weight", "module.transformer_encoder.layers.3.linear1.bias", "module.transformer_encoder.layers.3.linear2.weight", "module.transformer_encoder.layers.3.linear2.bias", "module.transformer_encoder.layers.3.norm1.weight", "module.transformer_encoder.layers.3.norm1.bias", "module.transformer_encoder.layers.3.norm2.weight", "module.transformer_encoder.layers.3.norm2.bias", "module.decoder.weight", "module.decoder.bias".
Xflick commented 1 year ago

Ah, CPU inference may not be supported. The checkpoint was trained and saved with `nn.DataParallel`, and if no GPU is specified at load time, the code does not wrap the model in `nn.DataParallel`, so the parameter names no longer match. A quick fix is to run inference on a GPU. Alternatively, strip the `module.` prefix from the state_dict keys, e.g. `state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}`.
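To illustrate, here is a minimal, self-contained sketch (using a tiny stand-in module rather than the repo's actual TransformerModel) of why the keys mismatch and how renaming them fixes it:

```python
import torch.nn as nn

# Tiny stand-in model just to show the key mismatch; in infer.py this role is
# played by the TransformerModel built from the config.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)

# Saving from a DataParallel-wrapped model prefixes every parameter key with "module."
wrapped = nn.DataParallel(TinyModel())
state_dict = wrapped.state_dict()
print(list(state_dict.keys()))  # ['module.encoder.weight', 'module.encoder.bias']

plain = TinyModel()
# plain.load_state_dict(state_dict)  # -> RuntimeError: Missing/Unexpected key(s)

# Renaming the keys makes the checkpoint loadable by the unwrapped model.
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
plain.load_state_dict(state_dict)
```

The same rename can be applied in infer.py to the loaded state_dict just before `load_state_dict` is called.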

AlirezaMorsali commented 1 year ago

Thank you for the response. I will try with a GPU. In the meantime, your suggestion fixed the issue on CPU, so I will close this issue.
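For anyone else running on a CPU-only machine, this is roughly the snippet I ended up with (assuming infer.py does not already pass `map_location`; the checkpoint path is the one from my arguments above):

```python
import torch

# Map the GPU-saved checkpoint onto the CPU, then drop the DataParallel "module." prefix
state_dict = torch.load("pretrained_models/large/model_callhome.th", map_location="cpu")
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
# model.load_state_dict(state_dict)  # where model is the TransformerModel built by infer.py
```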

Just as a suggestion, it would be very helpful to include a requirements.txt with pinned library versions if possible, as some libraries have had breaking changes since this code was written.