assafmu / wav2letter_pytorch

An implementation of the Wav2Letter Speech-to-Text model using PyTorch.

Error while running transcribe.py #1

Closed Liranbz closed 4 years ago

Liranbz commented 4 years ago

Hi, thank you so much for this amazing tool! I tried to use transcribe.py with a trained model. I used these arguments:

parser = argparse.ArgumentParser(description='Wav2Letter usage')
parser.add_argument('--test-manifest',metavar='DIR',help='path to test manifest csv', default=r"data/test.csv")
#change to receive file or file list
parser.add_argument('--cuda', default=False, dest='cuda', action='store_true', help='Use cuda to execute model')
parser.add_argument('--seed', type=int, default=1337)
parser.add_argument('--model-path', type=str, default=r"models/2_batchsize/final.pth",
                    help='Path to model.tar to use')
parser.add_argument('--decoder', type=str, default='greedy',
                    help='Type of decoder to use. "greedy", or "beam". If "beam", an LM can be specified with "--lm-path"')
parser.add_argument('--lm-path', type=str, default='',
                    help='Path to arpa lm file to use for testing. Default is no LM.')
parser.add_argument('--beam-search-params', type=str, default='5,0.3,5,1e-3',
                    help='comma separated value for k,alpha,beta,prune. For example, 5,0.3,5,1e-3')
parser.add_argument('--arc', default='wav2letter', type=str,
                    help='Network architecture to use. Can be either "quartz"  or "wav2letter" (default)')
parser.add_argument('--mel-spec-count', default=0, type=int, help='How many channels to use in Mel Spectrogram')
parser.add_argument('--use-mel-spec', dest='mel_spec_count', action='store_const', const=64,
                    help='Use mel spectrogram with default value (64)')

and I got this error:

Traceback (most recent call last):
  File "C:/Users/liran_bz/PycharmProjects/wav2letter-decoder/transcribe.py", line 93, in <module>
    transcribe(**vars(arguments))
  File "C:/Users/liran_bz/PycharmProjects/wav2letter-decoder/transcribe.py", line 61, in transcribe
    model = get_model(kwargs)
  File "C:/Users/liran_bz/PycharmProjects/wav2letter-decoder/transcribe.py", line 53, in get_model
    model = arc.load_model(kwargs['model_path'])
  File "C:\Users\liran_bz\PycharmProjects\wav2letter-decoder\wav2letter.py", line 114, in load_model
    return cls.load_model_package(package)
  File "C:\Users\liran_bz\PycharmProjects\wav2letter-decoder\wav2letter.py", line 119, in load_model_package
    model.load_state_dict(package['state_dict'])
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Wav2Letter:
    Missing key(s) in state_dict: "conv1ds.conv1d_0.conv1.weight", "conv1ds.conv1d_0.conv1.bias", "conv1ds.conv1d_0.batch_norm.weight", "conv1ds.conv1d_0.batch_norm.bias", "conv1ds.conv1d_0.batch_norm.running_mean", "conv1ds.conv1d_0.batch_norm.running_var", "conv1ds.conv1d_1.conv1.weight", "conv1ds.conv1d_1.conv1.bias", "conv1ds.conv1d_1.batch_norm.weight", "conv1ds.conv1d_1.batch_norm.bias", "conv1ds.conv1d_1.batch_norm.running_mean", "conv1ds.conv1d_1.batch_norm.running_var", "conv1ds.conv1d_2.conv1.weight", "conv1ds.conv1d_2.conv1.bias", "conv1ds.conv1d_2.batch_norm.weight", "conv1ds.conv1d_2.batch_norm.bias", "conv1ds.conv1d_2.batch_norm.running_mean", "conv1ds.conv1d_2.batch_norm.running_var", "conv1ds.conv1d_3.conv1.weight", "conv1ds.conv1d_3.conv1.bias", "conv1ds.conv1d_3.batch_norm.weight", "conv1ds.conv1d_3.batch_norm.bias", "conv1ds.conv1d_3.batch_norm.running_mean", "conv1ds.conv1d_3.batch_norm.running_var", "conv1ds.conv1d_4.conv1.weight", "conv1ds.conv1d_4.conv1.bias", "conv1ds.conv1d_4.batch_norm.weight", "conv1ds.conv1d_4.batch_norm.bias", "conv1ds.conv1d_4.batch_norm.running_mean", "conv1ds.conv1d_4.batch_norm.running_var", "conv1ds.conv1d_5.conv1.weight", "conv1ds.conv1d_5.conv1.bias", "conv1ds.conv1d_5.batch_norm.weight", "conv1ds.conv1d_5.batch_norm.bias", "conv1ds.conv1d_5.batch_norm.running_mean", "conv1ds.conv1d_5.batch_norm.running_var", "conv1ds.conv1d_6.conv1.weight", "conv1ds.conv1d_6.conv1.bias". 
    Unexpected key(s) in state_dict: "jasper_encoder.0.mconv.0.conv.weight", "jasper_encoder.0.mconv.1.conv.weight", "jasper_encoder.0.mconv.2.weight", "jasper_encoder.0.mconv.2.bias", "jasper_encoder.0.mconv.2.running_mean", "jasper_encoder.0.mconv.2.running_var", "jasper_encoder.0.mconv.2.num_batches_tracked", "jasper_encoder.1.mconv.0.conv.weight", "jasper_encoder.1.mconv.1.conv.weight", "jasper_encoder.1.mconv.2.weight", "jasper_encoder.1.mconv.2.bias", "jasper_encoder.1.mconv.2.running_mean", "jasper_encoder.1.mconv.2.running_var", "jasper_encoder.1.mconv.2.num_batches_tracked", "jasper_encoder.1.res.0.0.conv.weight", "jasper_encoder.1.res.0.1.weight", "jasper_encoder.1.res.0.1.bias", "jasper_encoder.1.res.0.1.running_mean", "jasper_encoder.1.res.0.1.running_var", "jasper_encoder.1.res.0.1.num_batches_tracked", "jasper_encoder.2.mconv.0.conv.weight", "jasper_encoder.2.mconv.1.conv.weight", "jasper_encoder.2.mconv.2.weight", "jasper_encoder.2.mconv.2.bias", "jasper_encoder.2.mconv.2.running_mean", "jasper_encoder.2.mconv.2.running_var", "jasper_encoder.2.mconv.2.num_batches_tracked", "jasper_encoder.2.res.0.0.conv.weight", "jasper_encoder.2.res.0.1.weight", "jasper_encoder.2.res.0.1.bias", "jasper_encoder.2.res.0.1.running_mean", "jasper_encoder.2.res.0.1.running_var", "jasper_encoder.2.res.0.1.num_batches_tracked", "jasper_encoder.3.mconv.0.conv.weight", "jasper_encoder.3.mconv.1.conv.weight", "jasper_encoder.3.mconv.2.weight", "jasper_encoder.3.mconv.2.bias", "jasper_encoder.3.mconv.2.running_mean", "jasper_encoder.3.mconv.2.running_var", "jasper_encoder.3.mconv.2.num_batches_tracked", "jasper_encoder.3.res.0.0.conv.weight", "jasper_encoder.3.res.0.1.weight", "jasper_encoder.3.res.0.1.bias", "jasper_encoder.3.res.0.1.running_mean", "jasper_encoder.3.res.0.1.running_var", "jasper_encoder.3.res.0.1.num_batches_tracked", "final_layer.0.weight", "final_layer.0.bias". 

Do I need to add or change any parameters? Thank you!

assafmu commented 4 years ago

It would seem the default model architecture differs between transcribe.py and train.py: in train.py it is quartznet (implemented in the Jasper class), while in transcribe.py it is Wav2Letter. This causes an error when loading the model from a file, since the saved weights don't match the model being built.
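To illustrate, here is a minimal sketch of why the keys mismatch. The two classes below are hypothetical stand-ins for the quartznet/Jasper and Wav2Letter modules, not the actual implementations in this repo:

import torch.nn as nn

# Hypothetical stand-in for the quartznet (Jasper) model that train.py saves by default.
class QuartzLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.jasper_encoder = nn.Conv1d(64, 256, kernel_size=11)

# Hypothetical stand-in for the Wav2Letter model that transcribe.py builds by default.
class Wav2LetterLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1ds = nn.Conv1d(64, 256, kernel_size=11)

saved = QuartzLike().state_dict()   # saved keys look like "jasper_encoder.*"
model = Wav2LetterLike()            # model expects keys like "conv1ds.*"
model.load_state_dict(saved)        # RuntimeError: missing/unexpected key(s) in state_dict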

For now, consider passing --arc quartz to the transcribe.py script.
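For example, using the manifest and model paths from the argument defaults in your snippet (adjust to your own files):

python transcribe.py --arc quartz --model-path models/2_batchsize/final.pth --test-manifest data/test.csv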

Keeping this open until we make the default model architecture (and other parameters) consistent between the two scripts.

Liranbz commented 4 years ago

You're right, the model was trained with --arc quartz, not with the Wav2Letter architecture. Thank you!