kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

RuntimeError: The size of tensor a (2055) must match the size of tensor b (99) at non-singleton dimension 3 #19

Closed · 980202006 closed 2 years ago

980202006 commented 2 years ago

I use a trained model for inference, and I encounter this problem when the input file is long:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xingyum/anaconda3/envs/ba3l/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xingyum/models/PaSST/output/openmic2008/_None/checkpoints/src/hear21passt/hear21passt/wrapper.py", line 38, in forward
    x, features = self.net(specs)
  File "/home/xingyum/anaconda3/envs/ba3l/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xingyum/models/PaSST/output/openmic2008/_None/checkpoints/src/hear21passt/hear21passt/models/passt.py", line 507, in forward
    x = self.forward_features(x)
  File "/home/xingyum/models/PaSST/output/openmic2008/_None/checkpoints/src/hear21passt/hear21passt/models/passt.py", line 454, in forward_features
    x = x + time_new_pos_embed
RuntimeError: The size of tensor a (2055) must match the size of tensor b (99) at non-singleton dimension 3
```
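For context, here is a minimal sketch of why the addition in `forward_features` fails. The time sizes (2055 and 99) come from the error message; the tensor layout and the other dimensions are illustrative assumptions, not taken from the repository:

```python
import torch

# Hypothetical shapes: (batch, embed_dim, freq_patches, time_patches).
# The learned time positional embedding only covers the number of time
# patches seen during training (99 here), while a long input produces 2055.
x = torch.randn(1, 768, 12, 2055)                 # patch grid for a long clip
time_new_pos_embed = torch.randn(1, 768, 12, 99)  # sized for the training clip length
x = x + time_new_pos_embed                        # RuntimeError at dimension 3
```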

980202006 commented 2 years ago
```python
import torch

from hear21passt.base import get_basic_model
from hear21passt.models.passt import get_model as get_model_passt

model = get_basic_model(mode="logits")
# replace the transformer to get a 4-class output head
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=4)
ckpt = torch.load('./output/openmic2008/_None/checkpoints/epoch=0-step=47498.ckpt')

# main weights
net_statedict = {k[4:]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}
# SWA-averaged weights
net_swa = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}

# loading the fine-tuned weights
model.net.load_state_dict(net_statedict)
wave_signal = torch.randn([1, 16000 * 10])
logits = model(wave_signal)
```
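The checkpoint also carries the SWA-averaged weights extracted into `net_swa` above; one way to evaluate those instead (an assumption based on the key names in this snippet, not something stated in the thread):

```python
# Alternatively, load the SWA-averaged weights extracted above.
model.net.load_state_dict(net_swa)
```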
kkoutini commented 2 years ago

Hi, when you train on shorter audio clips, you only have time positional encodings for that length (judging by the shape you're showing, it's 10 seconds). To overcome this, I suggest two possibilities:

  1. Get the predictions on overlapping 10-second windows (or whatever window length your model supports) and aggregate these predictions. You can look at this function for an example; a minimal sketch of this approach follows the list.

  2. Sample the time positional encoding during training, to get encodings for longer audio than you see at training time. This is how I trained the models that support up to 20 s and 30 s of audio. You can fine-tune your model starting from any of these models, and the resulting model will support 20- or 30-second audio. For even longer clips, however, you would still need the sliding window.
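A minimal sketch of option 1, assuming the `model` wrapper from the snippet above, 32 kHz input, and a 10-second supported window; the helper name `predict_long_audio`, the hop length, and the mean aggregation are illustrative choices, not the repository's own implementation:

```python
import torch
import torch.nn.functional as F

def predict_long_audio(model, wave, sample_rate=32000,
                       window_seconds=10.0, hop_seconds=5.0):
    """Average logits over overlapping fixed-length windows.

    `wave` has shape (1, num_samples); window/hop sizes are illustrative.
    """
    window = int(sample_rate * window_seconds)
    hop = int(sample_rate * hop_seconds)
    all_logits = []
    model.eval()
    with torch.no_grad():
        for start in range(0, wave.shape[1], hop):
            chunk = wave[:, start:start + window]
            if chunk.shape[1] < window:
                # zero-pad the last window to the supported length
                chunk = F.pad(chunk, (0, window - chunk.shape[1]))
            all_logits.append(model(chunk))
            if start + window >= wave.shape[1]:
                break
    # mean over windows; max or logit-space weighting are alternatives
    return torch.stack(all_logits).mean(dim=0)
```

Mean aggregation is the simplest choice here; the repository's own overlap-inference helper may aggregate predictions differently.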

980202006 commented 2 years ago

Thank you! I will try it.