YuanGongND / ast

Code for the Interspeech 2021 paper "AST: Audio Spectrogram Transformer".
BSD 3-Clause "New" or "Revised" License

Inference time mismatch errors? #26

Closed: Enescigdem closed this issue 2 years ago

Enescigdem commented 2 years ago

Hello, I trained a base384-size AST model on my own dataset. There were no errors during training, but when I tried to run inference and load the checkpoint, the error below arose.

RuntimeError: Error(s) in loading state_dict for DataParallel: size mismatch for module.v.pos_embed: copying a param with shape torch.Size([1, 602, 768]) from checkpoint, the shape in current model is torch.Size([1, 1214, 768]).

What could be causing this error?

YuanGongND commented 2 years ago

How did you initialize the AST model before calling load_state_dict? Did you use the same input_tdim as was used for your training? This error means there is a difference in input_tdim between your trained model checkpoint and the AST model object you constructed.
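
For reference, with fstride=tstride=10 and input_fdim=128, AST splits the input into a 12 × 50 patch grid when input_tdim=512, giving 600 patches + 2 tokens = 602 positions; 1214 positions corresponds to a 12 × 101 grid, i.e. an input_tdim of roughly 1024. A minimal loading sketch follows, assuming the checkpoint was saved from a DataParallel wrapper (which is why its keys carry the module. prefix); the checkpoint path and label_dim here are placeholders:

```python
import torch
from src.models import ASTModel  # adjust the import path to your setup

# Placeholders: use the exact values from your training run. input_tdim
# determines the number of positional embeddings, so it must match the
# checkpoint or load_state_dict will report a pos_embed size mismatch.
input_tdim = 512
label_dim = 13

audio_model = ASTModel(label_dim=label_dim, fstride=10, tstride=10,
                       input_fdim=128, input_tdim=input_tdim,
                       imagenet_pretrain=True, audioset_pretrain=False,
                       model_size='base384')

# The "module.v.pos_embed" key in the error shows the checkpoint was saved
# under DataParallel, so wrap the model the same way before loading.
audio_model = torch.nn.DataParallel(audio_model)
checkpoint = torch.load('path/to/best_audio_model.pth', map_location='cuda')
audio_model.load_state_dict(checkpoint)
```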

Enescigdem commented 2 years ago

I adjusted the input_tdim and ran it again. Now the error below appears.

x = x + self.v.pos_embed
RuntimeError: The size of tensor a (1214) must match the size of tensor b (602) at non-singleton dimension 1

YuanGongND commented 2 years ago

Could you paste the model initialization code for training and inference? Thanks.

YuanGongND commented 2 years ago

This means your actual input length is not the same as what you declared. Are you using audio clips of the same length for training and inference?

Enescigdem commented 2 years ago

I used the ESC-50 recipe shell script, with parameters as follows: freqm=24, timem=96, mixup=0, fstride=10, tstride=10. Then, in src/run.py:

audio_model = models.ASTModel(label_dim=args.n_class, fstride=10, tstride=10, input_fdim=128, input_tdim=512, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384')

For inference:

checkpoint = torch.load(checkpoint_path, map_location='cuda')
audio_model = ASTModel(label_dim=13, fstride=10, tstride=10, input_fdim=128, input_tdim=512, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=False)
audio_model = torch.nn.DataParallel(audio_model)

YuanGongND commented 2 years ago

That looks fine. One minor thing: we use audioset_pretrain=True in our ESC-50 recipe, which leads to an ~8% accuracy improvement on ESC-50, so if your task is also audio event classification, you may want to use that.
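
That change is just the audioset_pretrain flag in the same constructor call. A sketch, assuming the configuration above; note the released AudioSet checkpoint is for model_size='base384' with fstride=tstride=10, which this setup already matches:

```python
# Same initialization as in the training recipe, with AudioSet
# pretraining enabled; the repo fetches the checkpoint on first use.
audio_model = models.ASTModel(label_dim=args.n_class, fstride=10, tstride=10,
                              input_fdim=128, input_tdim=512,
                              imagenet_pretrain=True,
                              audioset_pretrain=True,  # was False above
                              model_size='base384')
```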

The second error you got is likely because you actually fed a longer sequence to the model than the declared input_tdim. Can you check the shape of your input to the AST model?
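
One way to check this is sketched below, assuming the repo's Kaldi-fbank frontend; the filename and target length are placeholders. AST expects input of shape (batch, input_tdim, input_fdim), here (B, 512, 128), and 1214 positions at inference time implies roughly 1024 time frames actually reached the model:

```python
import torch
import torchaudio

# Placeholder file; substitute one of your actual inference clips.
waveform, sr = torchaudio.load('example.wav')

# Same frontend the repo's dataloader uses to build the spectrogram.
fbank = torchaudio.compliance.kaldi.fbank(
    waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
    window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)
print(fbank.shape)  # (n_frames, 128); n_frames must equal input_tdim

# Pad or truncate to input_tdim, mirroring the training-time dataloader.
target_length = 512
n_frames = fbank.shape[0]
if n_frames < target_length:
    # pad extra time frames with zeros at the end
    fbank = torch.nn.functional.pad(fbank, (0, 0, 0, target_length - n_frames))
else:
    fbank = fbank[:target_length, :]
```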

Enescigdem commented 2 years ago

Thank you for your help. I solved the problem; it stemmed from the feature extraction part.

I also want to ask: how can I use other image models from the timm package in this setup?

YuanGongND commented 2 years ago

Great to know the problem was solved.

To use other vision models, you need to change ast_models.py. Note that the DeiT model we use has two CLS tokens, while other vision models typically have only one CLS token, so you also need to modify the forward function. We use timm==0.4.5; newer versions have some differences.
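
An illustrative sketch of that token difference, assuming timm==0.4.5 and a standard single-CLS-token ViT as the replacement backbone; names mirror ast_models.py, and the repo's positional-embedding interpolation and 1-channel patch-projection swap are omitted for brevity:

```python
import timm
import torch
import torch.nn as nn

# Hypothetical adaptation: a plain ViT in place of AST's distilled DeiT
# (which prepends CLS + distillation tokens and averages the two).
class VitAST(nn.Module):
    def __init__(self, label_dim=13):
        super().__init__()
        self.v = timm.create_model('vit_base_patch16_384', pretrained=True)
        self.mlp_head = nn.Linear(self.v.embed_dim, label_dim)

    def forward(self, x):
        # x: patch embeddings of shape (batch, num_patches, embed_dim),
        # produced as in ast_models.py.
        cls_tokens = self.v.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # 1 leading token, not 2
        x = x + self.v.pos_embed               # sized for 1 + num_patches
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        x = x[:, 0]  # single CLS token; DeiT averages x[:, 0] and x[:, 1]
        return self.mlp_head(x)
```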

Enescigdem commented 2 years ago

Thank you, I'll try that.