minzwon / sota-music-tagging-models

MIT License
Error in loading the model during the training process #1

sainathadapa closed this issue 4 years ago

sainathadapa commented 4 years ago

Hi, I'm trying to retrain the ShortChunkCNN model on the MagnaTagATune dataset. Training errors out at the 80th epoch, when self.load is called from opt_schedule:

RuntimeError: Error(s) in loading state_dict for ShortChunkCNN:
        size mismatch for spec.mel_scale.fb: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([257, 128]).

Can you tell me why this is happening? Also, can you tell me why the following snippet of code is required in self.load?

if 'spec.mel_scale.fb' in S.keys():
    S['spec.mel_scale.fb'] = torch.tensor([])

Thanks for the nice paper, and neat code!

minzwon commented 4 years ago

Hi, thank you for your interest!

This is caused by an initialization issue in torchaudio.transforms.MelSpectrogram. When you first initialize a model, the shape of model.spec.mel_scale.fb is torch.Size([0]). However, once you run model.spec(x) or perform some other actions, the filterbank is populated and the shape becomes torch.Size([257, 128]).

Previously, I tried to work around this issue with

if 'spec.mel_scale.fb' in S.keys():
    S['spec.mel_scale.fb'] = torch.tensor([])

but that looks like a version-dependent, temporary solution.
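
To illustrate the mismatch, here is a minimal sketch. `LazyFilterbank` is a hypothetical stand-in (not the torchaudio class) that registers an empty `fb` buffer and only fills it on the first forward pass, which is the behavior described above. Overwriting the checkpoint entry with an empty tensor then fails once the model has already been run, which would be consistent with training breaking when self.load is called partway through a run.

```python
import torch
import torch.nn as nn


class LazyFilterbank(nn.Module):
    """Hypothetical stand-in for a layer whose `fb` buffer is built on first use."""

    def __init__(self, n_mels=128):
        super().__init__()
        self.n_mels = n_mels
        # Registered empty, so a fresh model reports torch.Size([0]).
        self.register_buffer("fb", torch.empty(0))

    def forward(self, spec):
        # spec has shape (batch, n_freqs, time).
        if self.fb.numel() == 0:
            # Placeholder values; a real layer would compute mel filters here.
            self.fb = torch.ones(spec.size(1), self.n_mels)
        # (batch, time, n_freqs) @ (n_freqs, n_mels) -> (batch, n_mels, time)
        return torch.matmul(spec.transpose(1, 2), self.fb).transpose(1, 2)


model = LazyFilterbank()
shape_before = tuple(model.fb.shape)   # buffer is still empty
model(torch.randn(1, 257, 10))         # first forward pass fills the buffer
shape_after = tuple(model.fb.shape)

# Mimic the old workaround: overwrite the checkpoint entry with an empty tensor.
S = model.state_dict()
S["fb"] = torch.tensor([])

try:
    model.load_state_dict(S)
    error_message = ""
except RuntimeError as err:
    # The checkpoint entry is empty but the live buffer is not.
    error_message = str(err)

print(shape_before, shape_after)
print(error_message)
```

The empty-tensor override only matches a model whose buffer has never been populated, so it depends on when the checkpoint is loaded.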

I fixed the lines as follows.

if 'spec.mel_scale.fb' in S.keys():
    S['spec.mel_scale.fb'] = S['spec.mel_scale.fb']

If you pull the most recent version, it will work now.
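
One way to make loading independent of whether the buffer has been built yet is to copy the checkpoint's tensor onto the model before calling load_state_dict, so that both shapes agree. The sketch below assumes this approach. `LazyFilterbank` and `load_checkpoint` are hypothetical names used only for illustration, and for the model in this thread the key would be 'spec.mel_scale.fb'. It is not necessarily the exact code in the repository.

```python
import torch
import torch.nn as nn


class LazyFilterbank(nn.Module):
    """Hypothetical stand-in for a layer whose `fb` buffer is built on first use."""

    def __init__(self, n_mels=128):
        super().__init__()
        self.n_mels = n_mels
        self.register_buffer("fb", torch.empty(0))

    def forward(self, spec):
        if self.fb.numel() == 0:
            # Placeholder values; a real layer would compute mel filters here.
            self.fb = torch.ones(spec.size(1), self.n_mels)
        return torch.matmul(spec.transpose(1, 2), self.fb).transpose(1, 2)


def load_checkpoint(model, S, lazy_keys=("fb",)):
    """Give lazily built buffers the checkpoint's shape, then load as usual."""
    for key in lazy_keys:
        if key not in S:
            continue
        # Walk dotted names such as 'spec.mel_scale.fb' down to the owning module.
        *path, leaf = key.split(".")
        module = model
        for name in path:
            module = getattr(module, name)
        current = getattr(module, leaf)
        # Replace the buffer so its shape matches the checkpoint entry.
        setattr(module, leaf, S[key].clone().to(current.device))
    model.load_state_dict(S)


# A "trained" model whose buffer has been populated by a forward pass.
trained = LazyFilterbank()
trained(torch.randn(1, 257, 10))
S = trained.state_dict()

# Loading works for a fresh model (empty buffer) ...
fresh = LazyFilterbank()
load_checkpoint(fresh, S)

# ... and for a model that has already been run.
used = LazyFilterbank()
used(torch.randn(1, 257, 10))
load_checkpoint(used, S)

print(tuple(fresh.fb.shape), tuple(used.fb.shape))
```

Because the buffer is replaced before load_state_dict runs, the shape check passes regardless of the model's state at load time.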

Best, Minz

minzwon commented 4 years ago

@sainathadapa Did that solve your problem?

sainathadapa commented 4 years ago

It did, thanks!