kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

Inference Issue #31

Open · Jerry2001 opened this issue 1 year ago

Jerry2001 commented 1 year ago

Hello,

First of all, thank you for the awesome and very well-written paper and repo.

I want to use the embeddings from these pre-trained models in my project. The following is the inference code I wrote for FSD50K.

```python
import torch
import numpy as np
import librosa
from hear21passt.base import get_basic_model, get_model_passt, get_scene_embeddings

# Wrapper (mel front end + network); swap in the FSD50K-trained PaSST
model = get_basic_model(mode="logits")
model.net = get_model_passt(arch="fsd50k-n", n_classes=200, fstride=16, tstride=16)
model.eval()
model = model.cuda()

# Load one clip as 32 kHz mono and repeat it to form a batch of 3
audio, sr = librosa.load("../dataset/fsd50k/mp3/FSD50K.dev_audio/102863.mp3", sr=32000, mono=True)
audio = torch.from_numpy(np.array([audio]))
audio_batch = torch.cat((audio, audio, audio), 0).cuda()

embed = get_scene_embeddings(audio_batch, model)  # works: embed.shape == [3, 1295]
model(audio_batch)  # raises the RuntimeError below for this clip
```

When I check embed.shape, I get torch.Size([3, 1295]), so I basically already have what I need. However, when I double-checked by trying to get the logits through model(), it gave me the following error:

```
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_13937/329924078.py in <module>
----> 1 model(audio_batch)

/data/scratch/ngop/.envs/vqgan2/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/data/scratch/ngop/src/hear21passt/hear21passt/wrapper.py in forward(self, x)
     36         specs = self.mel(x)
     37         specs = specs.unsqueeze(1)
---> 38         x, features = self.net(specs)
     39         if self.mode == "all":
     40             embed = torch.cat([x, features], dim=1)

/data/scratch/ngop/.envs/vqgan2/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/data/scratch/ngop/src/hear21passt/hear21passt/models/passt.py in forward(self, x)
    525         if first_RUN: print("x", x.size())
    526 
--> 527         x = self.forward_features(x)
    528 
    529         if self.head_dist is not None:

/data/scratch/ngop/src/hear21passt/hear21passt/models/passt.py in forward_features(self, x)
    472             time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]]
    473             if first_RUN: print(" CUT time_new_pos_embed.shape", time_new_pos_embed.shape)
--> 474         x = x + time_new_pos_embed
    475         if first_RUN: print(" self.freq_new_pos_embed.shape", self.freq_new_pos_embed.shape)
    476         x = x + self.freq_new_pos_embed

RuntimeError: The size of tensor a (135) must match the size of tensor b (99) at non-singleton dimension 3
```

However, when I tried a few other audio files from FSD50K, some produced logits and the correct prediction, while others gave errors like this one. What could the issue be? Do I need to worry about it, or can I just use the embedding? My other question is whether the input batch size is fixed: for the model I loaded, I have to input a batch of 3 audio clips. Is there a way to input a batch of a different size?

Jerry2001 commented 1 year ago

Also, for FSD50K, if sound files in the same batch have different lengths, should I zero-pad them to the same length before passing them to the model?

kkoutini commented 1 year ago

Hi! Thank you for your interest!

The problem is that the model you've loaded was trained on 10-second clips (AudioSet and cropped FSD50K), while the audio file you're processing is longer than 10 seconds (around 13.5 seconds, judging from the error). There are therefore not enough trained time positional encodings to cover the whole clip. get_scene_embeddings takes care of this by checking whether the audio is longer than the maximum length the model can handle here.

This is only a problem for inputs longer than 10 seconds: the model handles shorter clips by cropping the time positional encodings to match the input, here. If you use batched inputs, you can pad the shorter clips. If you run inference one clip at a time, the only constraint is that there are enough time positional encodings to cover the whole input. One possible workaround is to split the audio into (overlapping) 10-second windows and average the resulting embeddings; this is done here.
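The windowing workaround described above can be sketched as follows. This is a minimal NumPy sketch, not the repository's implementation: `estimate_duration` and `windowed_embedding` are hypothetical helpers, `embed_fn` is a placeholder for something that runs the issue's `model` on a single window and returns a NumPy vector, and the 32 kHz rate and 50% overlap are assumptions. `estimate_duration` also shows where the ~13.5 s figure comes from: the error compares 135 time patches in the input with the 99 time positional encodings the model has for 10 s, and the patch count grows roughly linearly with clip length.

```python
import numpy as np

SAMPLE_RATE = 32000                  # assumption: 32 kHz input, as in the issue's librosa.load call
WINDOW = int(10.0 * SAMPLE_RATE)     # 10 s: the longest input the loaded model covers


def estimate_duration(n_input_patches, n_trained_patches=99, trained_seconds=10.0):
    """Roughly convert the patch counts in the RuntimeError into seconds.

    Assumes the number of time patches grows linearly with clip length.
    """
    return trained_seconds * n_input_patches / n_trained_patches


def windowed_embedding(audio, embed_fn, window=WINDOW, hop=None):
    """Average embed_fn over overlapping windows that each fit the model.

    audio:    1-D NumPy waveform.
    embed_fn: hypothetical callable mapping a 1-D window to a 1-D embedding.
    hop:      step between window starts; defaults to 50% overlap (an assumption).
    """
    if hop is None:
        hop = window // 2
    if len(audio) <= window:
        # Clips up to 10 s fit the positional encodings as-is.
        return np.asarray(embed_fn(audio))
    starts = list(range(0, len(audio) - window + 1, hop))
    if starts[-1] + window < len(audio):
        # Add a final window aligned to the end so the tail is not dropped.
        starts.append(len(audio) - window)
    embeddings = [np.asarray(embed_fn(audio[s:s + window])) for s in starts]
    return np.mean(np.stack(embeddings), axis=0)
```

For example, `estimate_duration(135)` gives 1350/99 ≈ 13.6 s. A 13.5 s clip (432,000 samples) is covered by two 10 s windows starting at sample 0 and sample 112,000, and their embeddings are averaged.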

During training I'm cropping and padding the raw waveforms with zeros here.
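Cropping and zero-padding raw waveforms to a common length, which also answers the question about batching clips of different lengths, can be sketched as below. This is an illustrative NumPy sketch, not the repository's training code: `crop_or_pad` and `make_batch` are hypothetical helpers, and padding at the end of the clip and the random crop offset are assumptions.

```python
import numpy as np


def crop_or_pad(audio, target_len, random_crop=False, rng=None):
    """Return a float32 waveform of exactly target_len samples.

    Longer clips are cropped (from a random offset if random_crop, else from the start);
    shorter clips are zero-padded at the end.
    """
    audio = np.asarray(audio, dtype=np.float32)
    if len(audio) >= target_len:
        offset = 0
        if random_crop:
            if rng is None:
                rng = np.random.default_rng()
            offset = int(rng.integers(0, len(audio) - target_len + 1))
        return audio[offset:offset + target_len]
    return np.pad(audio, (0, target_len - len(audio)))  # constant mode: pads with zeros


def make_batch(clips, target_len):
    """Stack variable-length clips into a (batch, target_len) array."""
    return np.stack([crop_or_pad(clip, target_len) for clip in clips])
```

With `target_len = 10 * 32000`, clips of 2 s, 10 s and 13.5 s become the rows of a `(3, 320000)` float32 array, which `torch.from_numpy` can turn into a batch tensor like the one in the issue. Cropping a longer clip for inference discards audio, so the windowing approach is the better fit when the whole clip matters.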

I hope this helps.