qiuqiangkong / panns_inference

MIT License

Error(s) in loading state_dict for Cnn14: #8

Open callzhang opened 3 years ago

callzhang commented 3 years ago

Strictly following the sample code yields the following error:

Exception has occurred: RuntimeError
Error(s) in loading state_dict for Cnn14:
    size mismatch for fc_audioset.weight: copying a param with shape torch.Size([527, 2048]) from checkpoint, the shape in current model is torch.Size([0, 2048]).
    size mismatch for fc_audioset.bias: copying a param with shape torch.Size([527]) from checkpoint, the shape in current model is torch.Size([0]).
  File "/home/stardust/algorithms-playground/audio/kedaxunfei/audio_event_detection.py", line 19, in <module>
    at = AudioTagging(checkpoint_path=None, device='cuda')

Code to reproduce:

from glob import glob

import librosa
from panns_inference import AudioTagging, SoundEventDetection

paths = glob('path/to/audio/*.mp3')
audio_path = paths[0]
(audio, _) = librosa.core.load(audio_path, sr=32000, mono=True)
audio = audio[None, :]  # (batch_size, segment_samples)

print('------ Audio tagging ------')
at = AudioTagging(checkpoint_path=None, device='cuda')
(clipwise_output, embedding) = at.inference(audio)

print('------ Sound event detection ------')
sed = SoundEventDetection(checkpoint_path=None, device='cuda')
framewise_output = sed.inference(audio)

qiuqiangkong commented 3 years ago

Hi, one solution is to check the classes_num you defined. It seems to be 0 now, but it should be 527.
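Only the first dimension of the `fc_audioset` parameters differs, and that dimension is the number of output classes, which is why `classes_num` is the thing to check. A minimal, torch-free sketch of the shape comparison that `load_state_dict` performs is below; the helpers `shape_mismatches` and `fc_audioset_shapes` are hypothetical names, and the 2048-wide embedding is taken from the traceback in the original report.

```python
# Hypothetical helpers: compare parameter shapes the way load_state_dict does,
# using plain dicts of parameter name -> shape tuple. (load_state_dict also
# reports missing/unexpected keys; this sketch only checks size mismatches.)
def shape_mismatches(checkpoint_shapes, model_shapes):
    """Return one message per parameter present in both dicts whose shape differs."""
    messages = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            messages.append(
                f"size mismatch for {name}: checkpoint {tuple(ckpt_shape)}, "
                f"model {tuple(model_shape)}"
            )
    return messages


def fc_audioset_shapes(classes_num, embedding_dim=2048):
    """Shapes of the final linear layer: weight is (classes_num, embedding_dim)."""
    return {
        "fc_audioset.weight": (classes_num, embedding_dim),
        "fc_audioset.bias": (classes_num,),
    }


if __name__ == "__main__":
    checkpoint = fc_audioset_shapes(527)  # 527 AudioSet classes in the checkpoint
    broken_model = fc_audioset_shapes(0)  # classes_num == 0, as in the traceback
    for line in shape_mismatches(checkpoint, broken_model):
        print(line)
```

With PyTorch installed, the same dicts could be filled from real objects, e.g. `{k: tuple(v.shape) for k, v in model.state_dict().items()}`, and likewise from the checkpoint's state dict. That the panns checkpoints keep their weights under a `"model"` key is an assumption here; inspect the loaded object to confirm.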

Replying by email to Derek Zhang's post of Thu, 6 May 2021 at 01:18.

callzhang commented 3 years ago

But I didn't change anything after installing the package from pip. Any idea?

qiuqiangkong commented 3 years ago

Hi, did you manage to download the checkpoint successfully?
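One way to answer that question is to look at the files the package downloaded. A rough sketch follows. The `~/panns_data` location, the default checkpoint name `Cnn14_mAP=0.431.pth`, the label file name, and both size thresholds are assumptions to check against your installed version. The idea is that an interrupted download leaves a missing or truncated file, and an empty label list would be consistent with `classes_num` ending up as 0.

```python
from pathlib import Path

# Assumptions (verify against your installed panns_inference): with
# checkpoint_path=None the package keeps its files under ~/panns_data.
# The byte thresholds are rough guesses, not documented values.
PANNS_DIR = Path.home() / "panns_data"
EXPECTED_FILES = {
    "Cnn14_mAP=0.431.pth": 300_000_000,   # full Cnn14 checkpoint is a few hundred MB
    "class_labels_indices.csv": 1_000,    # AudioSet label list, should not be empty
}


def file_status(path, min_bytes):
    """Describe whether a file exists and is at least min_bytes long."""
    p = Path(path)
    if not p.is_file():
        return "missing"
    size = p.stat().st_size
    if size < min_bytes:
        return f"too small ({size} bytes), download probably incomplete"
    return f"ok ({size} bytes)"


if __name__ == "__main__":
    for name, min_bytes in EXPECTED_FILES.items():
        print(f"{name}: {file_status(PANNS_DIR / name, min_bytes)}")
```

If a file is reported missing or too small, deleting it and downloading it again is a reasonable next step.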

Replying by email to Derek Zhang's comment of Thu, 6 May 2021 at 10:43.

Dexter1618 commented 3 years ago

Using the new 16 kHz model, as updated on 24 August 2020, gave me a similar error. I haven't even passed an audio signal yet.

Line that caused the error: AudioTagging(checkpoint_path = "Cnn14_16k_mAP%3D0.438.pth", device = "cuda")

Error:

RuntimeError: Error(s) in loading state_dict for Cnn14:
    size mismatch for spectrogram_extractor.stft.conv_real.weight: copying a param with shape torch.Size([257, 1, 512]) from checkpoint, the shape in current model is torch.Size([513, 1, 1024]).
    size mismatch for spectrogram_extractor.stft.conv_imag.weight: copying a param with shape torch.Size([257, 1, 512]) from checkpoint, the shape in current model is torch.Size([513, 1, 1024]).
    size mismatch for logmel_extractor.melW: copying a param with shape torch.Size([257, 64]) from checkpoint, the shape in current model is torch.Size([513, 64]).

Please advise @qiuqiangkong

qiuqiangkong commented 3 years ago

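Hi Debayan,

This is because the window_size and the other feature-extraction settings (sample_rate, hop_size, mel_bins, fmin, fmax) must be the same as those the checkpoint was trained with. Try: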

CHECKPOINT_PATH="Cnn14_16k_mAP=0.438.pth"  # Trained by a later code version, achieves higher mAP than the paper.
wget -O $CHECKPOINT_PATH https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth?download=1
MODEL_TYPE="Cnn14_16k"
CUDA_VISIBLE_DEVICES=0 python3 pytorch/inference.py audio_tagging \
    --sample_rate=16000 \
    --window_size=512 \
    --hop_size=160 \
    --mel_bins=64 \
    --fmin=50 \
    --fmax=8000 \
    --model_type=$MODEL_TYPE \
    --checkpoint_path=$CHECKPOINT_PATH \
    --audio_path='resources/R9_ZSCveAHg_7s.wav' \
    --cuda
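The shapes in Debayan's traceback follow directly from window_size. A real STFT with n_fft = window_size has window_size // 2 + 1 frequency bins, so 512 gives 257 and 1024 gives 513. The mel filterbank has one row per bin and one column per mel band. The worked example below reproduces both sets of shapes. Two things are assumptions rather than documented API. The first is that the default model built by `AudioTagging` uses the 32 kHz settings shown (the model-side shapes in the traceback are consistent with window_size=1024). The second is the `frontend_shapes` helper itself, which is hypothetical.

```python
# Worked example: front-end parameter shapes implied by window_size and mel_bins.
# Assumption: STFT conv kernels are (window_size // 2 + 1, 1, window_size) and the
# mel filterbank melW is (window_size // 2 + 1, mel_bins), as the traceback shows.
def frontend_shapes(window_size, mel_bins):
    freq_bins = window_size // 2 + 1  # one-sided spectrum of a real signal
    return {
        "spectrogram_extractor.stft.conv_real.weight": (freq_bins, 1, window_size),
        "spectrogram_extractor.stft.conv_imag.weight": (freq_bins, 1, window_size),
        "logmel_extractor.melW": (freq_bins, mel_bins),
    }


# Settings from the command above (16 kHz checkpoint) and, as an assumption,
# the 32 kHz settings of the default model.
CNN14_16K = dict(sample_rate=16000, window_size=512, hop_size=160,
                 mel_bins=64, fmin=50, fmax=8000)
CNN14_32K = dict(sample_rate=32000, window_size=1024, hop_size=320,
                 mel_bins=64, fmin=50, fmax=14000)

if __name__ == "__main__":
    ckpt = frontend_shapes(CNN14_16K["window_size"], CNN14_16K["mel_bins"])
    model = frontend_shapes(CNN14_32K["window_size"], CNN14_32K["mel_bins"])
    for name in ckpt:
        print(f"{name}: checkpoint {ckpt[name]} vs default model {model[name]}")
```

The command uses `pytorch/inference.py`, which appears to come from the main PANNs repository (audioset_tagging_cnn) rather than the pip package. To use the 16 kHz checkpoint from Python instead, the model would need to be built with these same settings, and the audio loaded at 16 kHz (e.g. `librosa.load(path, sr=16000, mono=True)`). If your installed `AudioTagging` accepts a `model` argument, passing a `Cnn14` constructed with the 16 kHz settings and `classes_num=527` may work. That is an assumption, so check the class signature in your version first.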

Replying by email to Debayan Das's comment of Fri, 30 Jul 2021 at 20:28.