Closed: 980202006 closed this issue 2 years ago
import torch
from hear21passt.base import get_basic_model
from hear21passt.models.passt import get_model as get_model_passt

model = get_basic_model(mode="logits")
# replace the transformer head for the 4-class output
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=4)
ckpt = torch.load('./output/openmic2008/_None/checkpoints/epoch=0-step=47498.ckpt')
net_statedict = {k[len("net."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net.")}  # main weights
net_swa = {k[len("net_swa."):]: v.cpu() for k, v in ckpt['state_dict'].items() if k.startswith("net_swa.")}  # SWA (weight-averaged) weights
# load the fine-tuned weights (pass net_swa here instead to use the SWA weights)
model.net.load_state_dict(net_statedict)
wave_signal = torch.randn([1, 16000 * 10])
logits = model(wave_signal)
Hi, when you train on shorter audio clips, you only have time positional encodings for that length (for the shape you're showing, I guess it's 10 seconds). To overcome this I suggest two possibilities:
Get the predictions of overlapping 10-second windows (or whatever window length your model supports) and aggregate these predictions; you can look at this function for an example.
You can sample the time positional encoding during training to get encodings for audio longer than what you see at training time. This is how I trained the models that support up to 20 s and 30 s of audio. You can use any of these models to fine-tune yours, and the resulting model will support 20- or 30-second audio. For even longer clips, though, you'd still need the sliding window.
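The first option (overlapping windows, aggregated predictions) can be sketched roughly as below. This is not the linked function; the 32 kHz rate, 10 s window, 5 s hop, and mean aggregation are assumptions you would adapt to your model:

```python
import torch
import torch.nn.functional as F

def sliding_window_logits(model, wave, sr=32000, window_s=10, hop_s=5):
    """Average logits over overlapping windows of a long waveform.

    `model` is assumed to map a [batch, samples] waveform to
    [batch, n_classes] logits, as in the snippet above.
    """
    win, hop = window_s * sr, hop_s * sr
    n = wave.shape[-1]
    starts = list(range(0, max(n - win, 0) + 1, hop))
    if starts[-1] + win < n:           # make sure the tail is covered
        starts.append(n - win)
    all_logits = []
    with torch.no_grad():
        for s in starts:
            chunk = wave[..., s:s + win]
            if chunk.shape[-1] < win:  # zero-pad a clip shorter than one window
                chunk = F.pad(chunk, (0, win - chunk.shape[-1]))
            all_logits.append(model(chunk))
    return torch.stack(all_logits).mean(dim=0)
```

Mean aggregation over logits is just one choice; max-pooling or averaging after the sigmoid/softmax may suit other tasks better.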
Thank you! I will try it.
I use a trained model for inference, and I encounter this problem when the file is long:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/xingyum/anaconda3/envs/ba3l/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xingyum/models/PaSST/output/openmic2008/_None/checkpoints/src/hear21passt/hear21passt/wrapper.py", line 38, in forward
    x, features = self.net(specs)
  File "/home/xingyum/anaconda3/envs/ba3l/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xingyum/models/PaSST/output/openmic2008/_None/checkpoints/src/hear21passt/hear21passt/models/passt.py", line 507, in forward
    x = self.forward_features(x)
  File "/home/xingyum/models/PaSST/output/openmic2008/_None/checkpoints/src/hear21passt/hear21passt/models/passt.py", line 454, in forward_features
    x = x + time_new_pos_embed
RuntimeError: The size of tensor a (2055) must match the size of tensor b (99) at non-singleton dimension 3
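For context on the two numbers in that error: assuming PaSST's default front end (32 kHz audio, a mel hop of 320 samples, 16-frame patches with a time stride of 10 — these defaults are an assumption here), the 99 time patches of the positional embedding correspond to the 10 s clips the model was trained on, while 2055 patches correspond to a clip of several minutes. This is exactly the length mismatch the maintainer's answer addresses, and the sliding-window approach is the fix:

```python
def time_patches(seconds, sr=32000, hop=320, patch=16, stride=10):
    """Number of time patches PaSST extracts from `seconds` of audio,
    under the assumed default front-end parameters above."""
    frames = seconds * sr // hop               # spectrogram time frames
    return (frames - patch) // stride + 1      # patches after time striding
```

With these assumed defaults, `time_patches(10)` reproduces the 99 from the error message, and inverting the formula for 2055 patches puts the failing input at roughly 205 seconds of audio.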