
Issue with Model and Config File Mismatch in Wenet Conformer #2645

Open yhuangece7 opened 3 weeks ago

yhuangece7 commented 3 weeks ago

Describe the bug
There is a mismatch between the train.yaml configuration file and the loaded model weights (final.pt) when using the Wenet pretrained model wenetspeech_u2pp_conformer_exp. Specifically, when attempting to load the weights with the given configuration, several missing and unexpected keys are reported, which may indicate an inconsistency between the model architecture defined in the YAML file and the actual pretrained weights.

To Reproduce
Steps to reproduce the behavior:

  1. Download the wenetspeech_u2pp_conformer_exp.tar.gz model from the Wenet pretrained models page (https://wenet.org.cn/wenet/pretrained_models.en.html).
  2. Extract the downloaded archive on a Windows machine.
  3. Update train.yaml so that the paths for units.txt and global_cmvn point at the extracted directory (the default paths do not account for the fact that train.yaml sits in the same directory as units.txt and global_cmvn).
  4. Use the following Python script to verify the consistency between train.yaml and final.pt (a quicker key-prefix check follows the script):
import torch
import yaml
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.encoder import ConformerEncoder
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.ctc import CTC
import os

model_dir = "C:/path/to/model_directory"

# Load the configuration file with UTF-8 encoding
config_file = os.path.join(model_dir, "train.yaml")
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Extract relevant parameters from the config
vocab_size = config['output_dim']
encoder_conf = config['encoder_conf']
decoder_conf = config['decoder_conf']

# Remove unsupported parameters for TransformerDecoder
unsupported_decoder_keys = {'r_num_blocks'}
filtered_decoder_conf = {k: v for k, v in decoder_conf.items() if k not in unsupported_decoder_keys}

# Initialize Encoder and Decoder
encoder = ConformerEncoder(input_size=80, **encoder_conf)
decoder = TransformerDecoder(vocab_size=vocab_size, encoder_output_size=encoder_conf['output_size'], **filtered_decoder_conf)
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_conf['output_size'])

# Initialize the ASR model
model = ASRModel(vocab_size=vocab_size, encoder=encoder, decoder=decoder, ctc=ctc, **config['model_conf'])

# Load pretrained model weights
checkpoint_path = os.path.join(model_dir, "final.pt")
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Load weights and obtain missing and unexpected keys
load_result = model.load_state_dict(checkpoint, strict=False)

# Print missing and unexpected keys
print(f"Missing keys: {load_result.missing_keys}")
print(f"Unexpected keys: {load_result.unexpected_keys}")

Expected behavior
The script should load the pretrained weights without any missing or unexpected keys, indicating a consistent configuration between train.yaml and final.pt.

Screenshots
When I run the script above, it prints the missing and unexpected keys below:

Missing keys: ['encoder.embed.pos_enc.pe', 'decoder.embed.0.weight', 'decoder.embed.1.pe', 'decoder.after_norm.weight', 'decoder.after_norm.bias', 'decoder.output_layer.weight', 'decoder.output_layer.bias', 'decoder.decoders.0.self_attn.linear_q.weight', 'decoder.decoders.0.self_attn.linear_q.bias', 'decoder.decoders.0.self_attn.linear_k.weight', 'decoder.decoders.0.self_attn.linear_k.bias', 'decoder.decoders.0.self_attn.linear_v.weight', 'decoder.decoders.0.self_attn.linear_v.bias', 'decoder.decoders.0.self_attn.linear_out.weight', 'decoder.decoders.0.self_attn.linear_out.bias', 'decoder.decoders.0.src_attn.linear_q.weight', 'decoder.decoders.0.src_attn.linear_q.bias', 'decoder.decoders.0.src_attn.linear_k.weight', 'decoder.decoders.0.src_attn.linear_k.bias', 'decoder.decoders.0.src_attn.linear_v.weight', 'decoder.decoders.0.src_attn.linear_v.bias', 'decoder.decoders.0.src_attn.linear_out.weight', 'decoder.decoders.0.src_attn.linear_out.bias', 'decoder.decoders.0.feed_forward.w_1.weight', 'decoder.decoders.0.feed_forward.w_1.bias', 'decoder.decoders.0.feed_forward.w_2.weight', 'decoder.decoders.0.feed_forward.w_2.bias', 'decoder.decoders.0.norm1.weight', 'decoder.decoders.0.norm1.bias', 'decoder.decoders.0.norm2.weight', 'decoder.decoders.0.norm2.bias', 'decoder.decoders.0.norm3.weight', 'decoder.decoders.0.norm3.bias', 'decoder.decoders.1.self_attn.linear_q.weight', 'decoder.decoders.1.self_attn.linear_q.bias', 'decoder.decoders.1.self_attn.linear_k.weight', 'decoder.decoders.1.self_attn.linear_k.bias', 'decoder.decoders.1.self_attn.linear_v.weight', 'decoder.decoders.1.self_attn.linear_v.bias', 'decoder.decoders.1.self_attn.linear_out.weight', 'decoder.decoders.1.self_attn.linear_out.bias', 'decoder.decoders.1.src_attn.linear_q.weight', 'decoder.decoders.1.src_attn.linear_q.bias', 'decoder.decoders.1.src_attn.linear_k.weight', 'decoder.decoders.1.src_attn.linear_k.bias', 'decoder.decoders.1.src_attn.linear_v.weight', 'decoder.decoders.1.src_attn.linear_v.bias', 'decoder.decoders.1.src_attn.linear_out.weight', 'decoder.decoders.1.src_attn.linear_out.bias', 'decoder.decoders.1.feed_forward.w_1.weight', 'decoder.decoders.1.feed_forward.w_1.bias', 'decoder.decoders.1.feed_forward.w_2.weight', 'decoder.decoders.1.feed_forward.w_2.bias', 'decoder.decoders.1.norm1.weight', 'decoder.decoders.1.norm1.bias', 'decoder.decoders.1.norm2.weight', 'decoder.decoders.1.norm2.bias', 'decoder.decoders.1.norm3.weight', 'decoder.decoders.1.norm3.bias', 'decoder.decoders.2.self_attn.linear_q.weight', 'decoder.decoders.2.self_attn.linear_q.bias', 'decoder.decoders.2.self_attn.linear_k.weight', 'decoder.decoders.2.self_attn.linear_k.bias', 'decoder.decoders.2.self_attn.linear_v.weight', 'decoder.decoders.2.self_attn.linear_v.bias', 'decoder.decoders.2.self_attn.linear_out.weight', 'decoder.decoders.2.self_attn.linear_out.bias', 'decoder.decoders.2.src_attn.linear_q.weight', 'decoder.decoders.2.src_attn.linear_q.bias', 'decoder.decoders.2.src_attn.linear_k.weight', 'decoder.decoders.2.src_attn.linear_k.bias', 'decoder.decoders.2.src_attn.linear_v.weight', 'decoder.decoders.2.src_attn.linear_v.bias', 'decoder.decoders.2.src_attn.linear_out.weight', 'decoder.decoders.2.src_attn.linear_out.bias', 'decoder.decoders.2.feed_forward.w_1.weight', 'decoder.decoders.2.feed_forward.w_1.bias', 'decoder.decoders.2.feed_forward.w_2.weight', 'decoder.decoders.2.feed_forward.w_2.bias', 'decoder.decoders.2.norm1.weight', 'decoder.decoders.2.norm1.bias', 'decoder.decoders.2.norm2.weight', 'decoder.decoders.2.norm2.bias', 
'decoder.decoders.2.norm3.weight', 'decoder.decoders.2.norm3.bias'] Unexpected keys: ['encoder.global_cmvn.mean', 'encoder.global_cmvn.istd', 'decoder.left_decoder.embed.0.weight', 'decoder.left_decoder.after_norm.weight', 'decoder.left_decoder.after_norm.bias', 'decoder.left_decoder.output_layer.weight', 'decoder.left_decoder.output_layer.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_k.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.0.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.0.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.0.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.0.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.0.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.0.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.0.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.0.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.0.norm1.weight', 'decoder.left_decoder.decoders.0.norm1.bias', 'decoder.left_decoder.decoders.0.norm2.weight', 'decoder.left_decoder.decoders.0.norm2.bias', 'decoder.left_decoder.decoders.0.norm3.weight', 'decoder.left_decoder.decoders.0.norm3.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_k.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.1.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.1.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.1.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.1.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.1.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.1.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.1.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.1.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.1.norm1.weight', 'decoder.left_decoder.decoders.1.norm1.bias', 'decoder.left_decoder.decoders.1.norm2.weight', 'decoder.left_decoder.decoders.1.norm2.bias', 'decoder.left_decoder.decoders.1.norm3.weight', 'decoder.left_decoder.decoders.1.norm3.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_q.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_q.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_k.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_k.bias', 
'decoder.left_decoder.decoders.2.self_attn.linear_v.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_v.bias', 'decoder.left_decoder.decoders.2.self_attn.linear_out.weight', 'decoder.left_decoder.decoders.2.self_attn.linear_out.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_q.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_q.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_k.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_k.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_v.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_v.bias', 'decoder.left_decoder.decoders.2.src_attn.linear_out.weight', 'decoder.left_decoder.decoders.2.src_attn.linear_out.bias', 'decoder.left_decoder.decoders.2.feed_forward.w_1.weight', 'decoder.left_decoder.decoders.2.feed_forward.w_1.bias', 'decoder.left_decoder.decoders.2.feed_forward.w_2.weight', 'decoder.left_decoder.decoders.2.feed_forward.w_2.bias', 'decoder.left_decoder.decoders.2.norm1.weight', 'decoder.left_decoder.decoders.2.norm1.bias', 'decoder.left_decoder.decoders.2.norm2.weight', 'decoder.left_decoder.decoders.2.norm2.bias', 'decoder.left_decoder.decoders.2.norm3.weight', 'decoder.left_decoder.decoders.2.norm3.bias', 'decoder.right_decoder.embed.0.weight', 'decoder.right_decoder.after_norm.weight', 'decoder.right_decoder.after_norm.bias', 'decoder.right_decoder.output_layer.weight', 'decoder.right_decoder.output_layer.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.0.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.0.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_q.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.0.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.0.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.0.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.0.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.0.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.0.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.0.norm1.weight', 'decoder.right_decoder.decoders.0.norm1.bias', 'decoder.right_decoder.decoders.0.norm2.weight', 'decoder.right_decoder.decoders.0.norm2.bias', 'decoder.right_decoder.decoders.0.norm3.weight', 'decoder.right_decoder.decoders.0.norm3.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.1.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.1.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_q.weight', 
'decoder.right_decoder.decoders.1.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.1.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.1.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.1.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.1.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.1.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.1.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.1.norm1.weight', 'decoder.right_decoder.decoders.1.norm1.bias', 'decoder.right_decoder.decoders.1.norm2.weight', 'decoder.right_decoder.decoders.1.norm2.bias', 'decoder.right_decoder.decoders.1.norm3.weight', 'decoder.right_decoder.decoders.1.norm3.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_q.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_q.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_k.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_k.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_v.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_v.bias', 'decoder.right_decoder.decoders.2.self_attn.linear_out.weight', 'decoder.right_decoder.decoders.2.self_attn.linear_out.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_q.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_q.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_k.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_k.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_v.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_v.bias', 'decoder.right_decoder.decoders.2.src_attn.linear_out.weight', 'decoder.right_decoder.decoders.2.src_attn.linear_out.bias', 'decoder.right_decoder.decoders.2.feed_forward.w_1.weight', 'decoder.right_decoder.decoders.2.feed_forward.w_1.bias', 'decoder.right_decoder.decoders.2.feed_forward.w_2.weight', 'decoder.right_decoder.decoders.2.feed_forward.w_2.bias', 'decoder.right_decoder.decoders.2.norm1.weight', 'decoder.right_decoder.decoders.2.norm1.bias', 'decoder.right_decoder.decoders.2.norm2.weight', 'decoder.right_decoder.decoders.2.norm2.bias', 'decoder.right_decoder.decoders.2.norm3.weight', 'decoder.right_decoder.decoders.2.norm3.bias']

Desktop (please complete the following information): Windows (see step 2 above).

Smartphone (please complete the following information): N/A

Additional context

  1. The missing keys include parameters such as 'encoder.embed.pos_enc.pe', 'decoder.embed.0.weight', and other weights related to the decoder structure.
  2. The unexpected keys include parameters like 'encoder.global_cmvn.mean', 'decoder.left_decoder.embed.0.weight', and others.
  3. I am concerned that these mismatches may affect fine-tuning, since my goal is to use this model as a base for further training (a workaround sketch follows this list).
  4. I have checked the README but couldn't find a way to resolve this issue by modifying the train.yaml file. I am wondering if there is an updated version of the model or if this is a known issue that can be safely ignored.
  5. Any guidance or suggestions would be greatly appreciated, especially if there's a way to download a compatible version of the pretrained model.
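Regarding item 3 above, one possible workaround sketch, relevant only if one deliberately wants to keep the plain TransformerDecoder for fine-tuning: reuse the left-to-right half of the checkpoint's decoder by stripping the 'decoder.left_decoder.' prefix (the prefixes are taken from the unexpected-key list above). In the wenet versions I have seen, the positional-encoding buffers ('...pe') are deterministic sinusoids recomputed at construction time, so leaving them uninitialized should be harmless; the 'encoder.global_cmvn.*' keys still require attaching a GlobalCMVN module to the encoder, as in the sketch after the maintainer's reply below.

# Hypothetical workaround: keep TransformerDecoder and load only the
# left-to-right decoder weights from the U2++ checkpoint (this continues
# the script above, reusing its 'checkpoint' and 'model' variables).
remapped = {}
for k, v in checkpoint.items():
    if k.startswith("decoder.right_decoder."):
        continue  # TransformerDecoder has no right-to-left counterpart
    remapped[k.replace("decoder.left_decoder.", "decoder.", 1)] = v
load_result = model.load_state_dict(remapped, strict=False)
print(f"Missing keys: {load_result.missing_keys}")        # ideally just '...pe' buffers
print(f"Unexpected keys: {load_result.unexpected_keys}")  # 'encoder.global_cmvn.*' until CMVN is attached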

Thank you!

xingchensong commented 4 days ago

Use BiTransformerDecoder instead of TransformerDecoder.
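For readers landing here: the checkpoint is a U2++ model, so its decoder is the bidirectional BiTransformerDecoder (hence the decoder.left_decoder / decoder.right_decoder key prefixes), and the CMVN statistics are stored inside the model (hence encoder.global_cmvn.*). Below is a minimal sketch of the corresponding changes to the reporter's script, assuming a wenet release where the imports exist at these paths and load_cmvn takes an is_json flag; exact module paths and signatures vary between releases, so treat this as a sketch rather than a drop-in fix.

from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.decoder import BiTransformerDecoder
from wenet.utils.cmvn import load_cmvn

# Attach global CMVN to the encoder so the 'encoder.global_cmvn.*' keys match.
mean, istd = load_cmvn(os.path.join(model_dir, "global_cmvn"), True)  # True => JSON-format cmvn file (assumption)
global_cmvn = GlobalCMVN(torch.from_numpy(mean).float(),
                         torch.from_numpy(istd).float())
encoder = ConformerEncoder(input_size=80, global_cmvn=global_cmvn, **encoder_conf)

# Do NOT filter out r_num_blocks: it is exactly the parameter that makes
# BiTransformerDecoder build the left_decoder / right_decoder pair found
# in the checkpoint.
decoder = BiTransformerDecoder(vocab_size=vocab_size,
                               encoder_output_size=encoder_conf['output_size'],
                               **decoder_conf)
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_conf['output_size'])
model = ASRModel(vocab_size=vocab_size, encoder=encoder, decoder=decoder,
                 ctc=ctc, **config['model_conf'])

load_result = model.load_state_dict(
    torch.load(os.path.join(model_dir, "final.pt"), map_location="cpu"),
    strict=False)
print(f"Missing keys: {load_result.missing_keys}")        # expect at most '...pe' buffers
print(f"Unexpected keys: {load_result.unexpected_keys}")  # expect none

Alternatively, wenet's own wenet.utils.init_model.init_model builds the complete model (GlobalCMVN, encoder, bitransformer decoder, CTC) from the same config dict; its signature has changed across releases, so check the version you have installed.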