qiuqiangkong / audioset_tagging_cnn

MIT License

Shape doesn't match when inferencing Cnn14_16k model #16

Open jeremmyzong opened 4 years ago

jeremmyzong commented 4 years ago

Great work, and thanks for sharing!

When I run the following command, as described in the README:

```
python pytorch/inference.py audio_tagging --sample_rate=16000 --window_size=512 --hop_size=160 --mel_bins=64 --fmin=50 --fmax=8000 --model_type="Cnn14_16k" --checkpoint_path="Cnn14_16k_mAP=0.438.pth" --audio_path='resources/R9_ZSCveAHg_7s.mp3'
```

it raises this error:

```
Traceback (most recent call last):
  File "pytorch/inference.py", line 201, in <module>
    audio_tagging(args)
  File "pytorch/inference.py", line 42, in audio_tagging
    model.load_state_dict(checkpoint['model'])
  File "/home/zongbowen/anaconda2/envs/tensorflow/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Cnn14_16k:
	size mismatch for spectrogram_extractor.stft.conv_real.weight: copying a param with shape torch.Size([257, 1, 512]) from checkpoint, the shape in current model is torch.Size([129, 1, 256]).
	size mismatch for spectrogram_extractor.stft.conv_imag.weight: copying a param with shape torch.Size([257, 1, 512]) from checkpoint, the shape in current model is torch.Size([129, 1, 256]).
	size mismatch for logmel_extractor.melW: copying a param with shape torch.Size([257, 64]) from checkpoint, the shape in current model is torch.Size([129, 64]).
```
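The shapes in this traceback are consistent with the one-sided FFT bin count `window_size // 2 + 1`: the checkpoint tensors (257 = 512 // 2 + 1, kernel length 512) correspond to a 512-sample window, while the model being built (129 = 256 // 2 + 1, kernel length 256) appears to use a 256-sample window, even though `--window_size=512` was passed. The sketch below is a hypothetical diagnostic, not part of `audioset_tagging_cnn`. It assumes the STFT conv weights have shape `(window_size // 2 + 1, 1, window_size)` and `melW` has shape `(window_size // 2 + 1, mel_bins)`, as the traceback suggests. The helper names `expected_shapes` and `shape_mismatches` are illustrative.

```python
# Hypothetical diagnostic helpers (not part of audioset_tagging_cnn).
# Assumption: STFT conv weights are (window_size // 2 + 1, 1, window_size)
# and melW is (window_size // 2 + 1, mel_bins), as in the traceback shapes.


def expected_shapes(window_size, mel_bins):
    """Shapes the spectrogram/log-mel layers would have for these settings."""
    n_freq = window_size // 2 + 1  # number of one-sided FFT frequency bins
    return {
        "spectrogram_extractor.stft.conv_real.weight": (n_freq, 1, window_size),
        "spectrogram_extractor.stft.conv_imag.weight": (n_freq, 1, window_size),
        "logmel_extractor.melW": (n_freq, mel_bins),
    }


def _as_shape(value):
    # torch tensors expose .shape; plain tuples are used as-is
    return tuple(value.shape) if hasattr(value, "shape") else tuple(value)


def shape_mismatches(checkpoint_state, model_state):
    """Return {name: (checkpoint_shape, model_shape)} for shared keys whose shapes differ."""
    mismatches = {}
    for name, ckpt_value in checkpoint_state.items():
        if name not in model_state:
            continue  # missing/unexpected keys are a separate problem
        ckpt_shape = _as_shape(ckpt_value)
        model_shape = _as_shape(model_state[name])
        if ckpt_shape != model_shape:
            mismatches[name] = (ckpt_shape, model_shape)
    return mismatches


if __name__ == "__main__":
    # Reproduce the three mismatches from the traceback: 512-sample vs 256-sample window.
    for name, (ckpt, model) in shape_mismatches(
        expected_shapes(512, 64), expected_shapes(256, 64)
    ).items():
        print(f"{name}: checkpoint {ckpt} vs model {model}")
```

With PyTorch installed, the same comparison could be run on real weights before calling `load_state_dict`, e.g. `shape_mismatches(torch.load(checkpoint_path, map_location="cpu")["model"], model.state_dict())`. This reports which window size each side was built with instead of failing inside `load_state_dict`.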

qiuqiangkong commented 3 years ago

Hi, does the default Cnn14 model (32 kHz) work?