nandacv opened this issue 1 year ago
Hi, thank you!
I think that to get the best performance, it's better to retrain at 16 kHz. Alternatively, you can adapt the pre-trained model to accept 16 kHz input like this:
First, load the model as usual:
from hear21passt.base import get_basic_model, get_model_passt
model = get_basic_model(mode="logits")
Then replace the mel layer with this adapted config:
from hear21passt.models.preprocess import AugmentMelSTFT
model.mel = AugmentMelSTFT(n_mels=128, sr=16000, win_length=400, hopsize=160, n_fft=512,
                           freqm=48, timem=192, htk=False, fmin=0.0, fmax=None, norm=1,
                           fmin_aug_range=10, fmax_aug_range=1000)
You can compare it with the original mel layer here: https://github.com/kkoutini/passt_hear21/blob/4dd6b9e426f528e2e8409b9bacecf58a2f464548/hear21passt/base.py#L52
The main differences from the original are: sr=32000, win_length=800, hopsize=320, n_fft=1024. I hope this helps.
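A way to see where the adapted numbers come from (my own illustration, not from the repo): the window and hop are kept at the same duration as the 32 kHz config (25 ms window, 10 ms hop), so their sample counts simply halve, and n_fft drops to the next power of two:

```python
# Sketch: derive the 16 kHz STFT settings by scaling the original 32 kHz ones.
orig = {"sr": 32000, "win_length": 800, "hopsize": 320, "n_fft": 1024}
ratio = 16000 / orig["sr"]                    # 0.5

win_length = int(orig["win_length"] * ratio)  # 400 samples = 25 ms at 16 kHz
hopsize = int(orig["hopsize"] * ratio)        # 160 samples = 10 ms at 16 kHz
n_fft = 512                                   # smallest power of two >= win_length
print(win_length, hopsize, n_fft)             # 400 160 512
```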
The audio files I downloaded were in 32 kHz.
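Since 32 kHz is exactly twice 16 kHz, those files can be downsampled before being fed to the adapted model. A crude pure-Python sketch of 2:1 downsampling (in practice a proper resampler such as torchaudio.transforms.Resample is preferable, since pair-averaging is only a very rough anti-alias filter):

```python
# Crude 2:1 downsampling (32 kHz -> 16 kHz): average adjacent sample pairs
# as a simple low-pass, keeping one value per pair.
def downsample_2x(wave):
    wave = wave[: len(wave) // 2 * 2]  # trim to an even number of samples
    return [(wave[i] + wave[i + 1]) / 2 for i in range(0, len(wave), 2)]

wave_32k = [0.5] * (32000 * 10)        # stand-in for a 10 s clip at 32 kHz
wave_16k = downsample_2x(wave_32k)
print(len(wave_16k))                   # 160000 samples = 10 s at 16 kHz
```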
Thank you for the reply. Can you please confirm if the following code looks good?
from hear21passt.base import get_basic_model, get_model_passt
import torch

model = get_basic_model(mode="logits")
print(model.mel)  # Extracts mel spectrogram from raw waveforms.

from hear21passt.models.preprocess import AugmentMelSTFT
model.mel = AugmentMelSTFT(n_mels=128, sr=16000, win_length=400, hopsize=160, n_fft=512,
                           freqm=48, timem=192, htk=False, fmin=0.0, fmax=None, norm=1,
                           fmin_aug_range=10, fmax_aug_range=1000)

model.eval()
with torch.no_grad():
    # example audio batch: 3 waveforms of 10 seconds each at 16 kHz
    audio = torch.ones((3, 16000 * 10)) * 0.5
    logits = model(audio)
Also, I assume these logits should be passed through a sigmoid function to obtain the output class probabilities? Please correct me if I am wrong.
Thanks in advance.
Yes, this looks correct.
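For reference, a minimal sketch of the sigmoid step discussed above (pure Python, stand-in logits): AudioSet tagging is multi-label, so each class gets an independent sigmoid probability rather than a softmax, and thresholding the probabilities at 0.5 is equivalent to thresholding the logits at 0.

```python
import math

def sigmoid(x):
    # Standard logistic function: maps a logit to a probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-x))

logits = [-2.0, 0.0, 3.0]              # stand-in logits for three classes
probs = [sigmoid(x) for x in logits]   # independent per-class probabilities
active = [p > 0.5 for p in probs]      # classes predicted as present
print(active)                          # [False, False, True]
```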
Thank you for the code and inference script. I understand that the PaSST model was trained on AudioSet at a sampling rate of 32 kHz. I am trying to run inference with the pre-trained model. Could you please let me know whether I have to retrain the model on AudioSet at 16 kHz to run inference on 16 kHz data, or is there another way?
Also, I am curious why you used 32 kHz instead of the already available 16 kHz AudioSet data.
Thanks in advance.