Describe the bug
There is a mismatch between the train.yaml configuration file and the loaded model weights (final.pt) when using the Wenet pretrained model wenetspeech_u2pp_conformer_exp. Specifically, when attempting to load the weights with the given configuration, several missing and unexpected keys are reported, which may indicate inconsistency between the model architecture defined in the YAML file and the actual pretrained weights.
To Reproduce
Steps to reproduce the behavior:
1. Download the wenetspeech_u2pp_conformer_exp.tar.gz model from the Wenet pretrained models page (https://wenet.org.cn/wenet/pretrained_models.en.html).
2. Extract the downloaded archive on a Windows machine.
3. Update the train.yaml file to point to the correct paths for units.txt and global_cmvn (the default paths do not account for the fact that the YAML file sits in the same directory as units.txt and global_cmvn).
4. Run the following Python script to check the consistency between train.yaml and final.pt:
import torch
import yaml
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.encoder import ConformerEncoder
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.ctc import CTC
import os
model_dir = "C:/path/to/model_directory"
# Load the configuration file with UTF-8 encoding
config_file = os.path.join(model_dir, "train.yaml")
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# Extract relevant parameters from the config
vocab_size = config['output_dim']
encoder_conf = config['encoder_conf']
decoder_conf = config['decoder_conf']
# Remove unsupported parameters for TransformerDecoder
unsupported_decoder_keys = {'r_num_blocks'}
filtered_decoder_conf = {k: v for k, v in decoder_conf.items() if k not in unsupported_decoder_keys}
# Initialize Encoder and Decoder
encoder = ConformerEncoder(input_size=80, **encoder_conf)
decoder = TransformerDecoder(vocab_size=vocab_size, encoder_output_size=encoder_conf['output_size'], **filtered_decoder_conf)
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_conf['output_size'])
# Initialize the ASR model
model = ASRModel(vocab_size=vocab_size, encoder=encoder, decoder=decoder, ctc=ctc, **config['model_conf'])
# Load pretrained model weights
checkpoint_path = os.path.join(model_dir, "final.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Load weights and obtain missing and unexpected keys
load_result = model.load_state_dict(checkpoint, strict=False)
# Print missing and unexpected keys
print(f"Missing keys: {load_result.missing_keys}")
print(f"Unexpected keys: {load_result.unexpected_keys}")
Expected behavior
The script should load the pretrained weights without any missing or unexpected keys, indicating a consistent configuration between train.yaml and final.pt.
Screenshots
When I run the script above, it prints the missing and unexpected keys shown below:
Missing keys: ['encoder.embed.pos_enc.pe', 'decoder.embed.0.weight', 'decoder.embed.1.pe', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.decoders.1.self_attn.linear_q.weight', 'decoder.decoders.1.self_attn.linear_q.bias', 'decoder.decoders.1.self_attn.linear_k.weight', 'decoder.decoders.1.self_attn.linear_k.bias', 'decoder.decoders.1.self_attn.linear_v.weight', 'decoder.decoders.1.self_attn.linear_v.bias', 'decoder.decoders.1.self_attn.linear_out.weight', 'decoder.decoders.1.self_attn.linear_out.bias', 'decoder.decoders.1.src_attn.linear_q.weight', 'decoder.decoders.1.src_attn.linear_q.bias', 'decoder.decoders.1.src_attn.linear_k.weight', 'decoder.decoders.1.src_attn.linear_k.bias', 'decoder.decoders.1.src_attn.linear_v.weight', 'decoder.decoders.1.src_attn.linear_v.bias', 'decoder.decoders.1.src_attn.linear_out.weight', 'decoder.decoders.1.src_attn.linear_out.bias', 'decoder.decoders.1.feed_forward.w_1.weight', 'decoder.decoders.1.feed_forward.w_1.bias', 'decoder.decoders.1.feed_forward.w_2.weight', 'decoder.decoders.1.feed_forward.w_2.bias', 'decoder.decoders.1.norm1.weight', 'decoder.decoders.1.norm1.bias', 'decoder.decoders.1.norm2.weight', 'decoder.decoders.1.norm2.bias', 'decoder.decoders.1.norm3.weight', 'decoder.decoders.1.norm3.bias', 'decoder.decoders.2.self_attn.linear_q.weight', 'decoder.decoders.2.self_attn.linear_q.bias', 'decoder.decoders.2.self_attn.linear_k.weight', 'decoder.decoders.2.self_attn.linear_k.bias', 'decoder.decoders.2.self_attn.linear_v.weight', 'decoder.decoders.2.self_attn.linear_v.bias', 'decoder.decoders.2.self_attn.linear_out.weight', 'decoder.decoders.2.self_attn.linear_out.bias', 'decoder.decoders.2.src_attn.linear_q.weight', 'decoder.decoders.2.src_attn.linear_q.bias', 'decoder.decoders.2.src_attn.linear_k.weight', 'decoder.decoders.2.src_attn.linear_k.bias', 'decoder.decoders.2.src_attn.linear_v.weight', 'decoder.decoders.2.src_attn.linear_v.bias', 'decoder.decoders.2.src_attn.linear_out.weight', 'decoder.decoders.2.src_attn.linear_out.bias', 'decoder.decoders.2.feed_forward.w_1.weight', 'decoder.decoders.2.feed_forward.w_1.bias', 'decoder.decoders.2.feed_forward.w_2.weight', 'decoder.decoders.2.feed_forward.w_2.bias', 'decoder.decoders.2.norm1.weight', 'decoder.decoders.2.norm1.bias', 'decoder.decoders.2.norm2.weight', 'decoder.decoders.2.norm2.bias', 
'decoder.decoders.2.norm3.weight', 'decoder.decoders.2.norm3.bias'] Unexpected keys: ['encoder.global_cmvn.mean', 'encoder.global_cmvn.istd', 'decoder.left_decoder.embed.0.weight', 'decoder.left_decoder.after_norm.weight', 'decoder.left_decoder.after_norm.bias', 'decoder.left_decoder.output_layer.weight', 'decoder.left_decoder.output_layer.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_k.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.0.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.0.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.0.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.0.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.0.norm1.weight', 'decoder.left_decoder.decoders.0.norm1.bias', 'decoder.left_decoder.decoders.0.norm2.weight', 'decoder.left_decoder.decoders.0.norm2.bias', 'decoder.left_decoder.decoders.0.norm3.weight', 'decoder.left_decoder.decoders.0.norm3.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_k.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.1.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.1.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.1.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.1.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.1.norm1.weight', 'decoder.left_decoder.decoders.1.norm1.bias', 'decoder.left_decoder.decoders.1.norm2.weight', 'decoder.left_decoder.decoders.1.norm2.bias', 'decoder.left_decoder.decoders.1.norm3.weight', 'decoder.left_decoder.decoders.1.norm3.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_k.bias', 
'decoder.left_decoder.decoders.2.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.2.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.2.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.2.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.2.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.2.norm1.weight', 'decoder.left_decoder.decoders.2.norm1.bias', 'decoder.left_decoder.decoders.2.norm2.weight', 'decoder.left_decoder.decoders.2.norm2.bias', 'decoder.left_decoder.decoders.2.norm3.weight', 'decoder.left_decoder.decoders.2.norm3.bias', 'decoder.right_decoder.embed.0.weight', 'decoder.right_decoder.after_norm.weight', 'decoder.right_decoder.after_norm.bias', 'decoder.right_decoder.output_layer.weight', 'decoder.right_decoder.output_layer.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_q.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.0.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.0.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.0.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.0.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.0.norm1.weight', 'decoder.right_decoder.decoders.0.norm1.bias', 'decoder.right_decoder.decoders.0.norm2.weight', 'decoder.right_decoder.decoders.0.norm2.bias', 'decoder.right_decoder.decoders.0.norm3.weight', 'decoder.right_decoder.decoders.0.norm3.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_q.weight', 
'decoder.right_decoder.decoders.1.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.1.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.1.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.1.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.1.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.1.norm1.weight', 'decoder.right_decoder.decoders.1.norm1.bias', 'decoder.right_decoder.decoders.1.norm2.weight', 'decoder.right_decoder.decoders.1.norm2.bias', 'decoder.right_decoder.decoders.1.norm3.weight', 'decoder.right_decoder.decoders.1.norm3.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_q.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.2.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.2.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.2.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.2.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.2.norm1.weight', 'decoder.right_decoder.decoders.2.norm1.bias', 'decoder.right_decoder.decoders.2.norm2.weight', 'decoder.right_decoder.decoders.2.norm2.bias', 'decoder.right_decoder.decoders.2.norm3.weight', 'decoder.right_decoder.decoders.2.norm3.bias']
Desktop (please complete the following information):
OS: Windows
Smartphone (please complete the following information):
N/A
Additional context
The missing keys include parameters such as 'encoder.embed.pos_enc.pe', 'decoder.embed.0.weight', and other weights related to the decoder structure.
The unexpected keys include parameters like 'encoder.global_cmvn.mean', 'decoder.left_decoder.embed.0.weight', and others.
I am concerned that these mismatches may affect fine-tuning, since my goal is to use this model as a base for further training.
I have checked the README but could not find a way to resolve this issue by modifying the train.yaml file. I am wondering whether there is an updated version of the model, or whether this is a known issue that can be safely ignored.
Any guidance or suggestions would be greatly appreciated, especially if there is a way to download a compatible version of the pretrained model. Thank you!
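One more data point that may help triage: the unexpected 'decoder.left_decoder.*' / 'decoder.right_decoder.*' keys look like they come from a bidirectional (u2++) decoder, and 'encoder.global_cmvn.mean' / 'encoder.global_cmvn.istd' from an in-graph CMVN layer, so my script above may simply be building the wrong architecture rather than the checkpoint being broken. Below is a minimal sketch of how the model might be constructed so that the keys line up. It assumes the installed wenet version provides BiTransformerDecoder, GlobalCMVN, and load_cmvn with the signatures used here, and that train.yaml contains input_dim; please treat it as a guess, not a verified fix.

import torch
import yaml
import os
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.encoder import ConformerEncoder
from wenet.transformer.decoder import BiTransformerDecoder
from wenet.transformer.ctc import CTC
from wenet.transformer.cmvn import GlobalCMVN
from wenet.utils.cmvn import load_cmvn

model_dir = "C:/path/to/model_directory"
with open(os.path.join(model_dir, "train.yaml"), 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Attach the CMVN stats as an in-graph layer so that
# 'encoder.global_cmvn.mean' / 'encoder.global_cmvn.istd' become expected keys.
# The global_cmvn file in the release appears to be a JSON stats file
# (see is_json_cmvn in train.yaml).
mean, istd = load_cmvn(os.path.join(model_dir, "global_cmvn"), is_json=True)
global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(),
                         torch.from_numpy(istd).float())

vocab_size = config['output_dim']
encoder_conf = config['encoder_conf']
encoder = ConformerEncoder(input_size=config.get('input_dim', 80),
                           global_cmvn=global_cmvn, **encoder_conf)

# Keep r_num_blocks instead of filtering it out: it configures the reverse
# (right-to-left) decoder, which is where the 'decoder.left_decoder.*' and
# 'decoder.right_decoder.*' keys in final.pt come from.
decoder = BiTransformerDecoder(vocab_size=vocab_size,
                               encoder_output_size=encoder_conf['output_size'],
                               **config['decoder_conf'])
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_conf['output_size'])
model = ASRModel(vocab_size=vocab_size, encoder=encoder, decoder=decoder,
                 ctc=ctc, **config['model_conf'])

checkpoint = torch.load(os.path.join(model_dir, "final.pt"), map_location="cpu")
result = model.load_state_dict(checkpoint, strict=False)
# Ideally only deterministic buffers such as '*.pe' remain missing.
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)

If that loads cleanly apart from the positional-encoding pe buffers (which are deterministic and recomputed anyway), the train.yaml/final.pt pair is probably consistent and the mismatch is in how my script instantiated the model.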