kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

Inference on AudioSet #37

Open nandacv opened 1 year ago

nandacv commented 1 year ago

Thank you for the code and inference script. I understand that the PaSST model has been trained on AudioSet with sampling rate of 32kHz. I am trying to make inference using the pre trained model. Could you please let me know if I have to retrain the model with AudioSet (sampling rate of 16kHz) data to use it to make inference on 16kHz data or is there any other way?

Also, curious to know why did you use 32kHz instead of already available 16kHz AudioSet data?

Thanks in advance.

kkoutini commented 1 year ago

Hi, thank you!

I think that in order to get the best performance, it's better to retrain on 16 kHz. Alternatively, you can adapt the pre-trained model to accept 16 kHz input like this:

First get the models as usual:

from hear21passt.base import get_basic_model, get_model_passt

model = get_basic_model(mode="logits")

Then replace the mel layer with this adapted config:


from hear21passt.models.preprocess import AugmentMelSTFT

model.mel = AugmentMelSTFT(n_mels=128, sr=16000, win_length=400, hopsize=160, n_fft=512,
                           freqm=48, timem=192, htk=False, fmin=0.0, fmax=None, norm=1,
                           fmin_aug_range=10, fmax_aug_range=1000)

You can compare it with the original mel layer here: https://github.com/kkoutini/passt_hear21/blob/4dd6b9e426f528e2e8409b9bacecf58a2f464548/hear21passt/base.py#L52 The main differences in the original were: sr=32000, win_length=800, hopsize=320, n_fft=1024. I hope this helps.
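As a quick sanity check (this is my own arithmetic, not from the repo), the adapted 16 kHz config halves sr, win_length, hopsize, and n_fft together, so the spectrogram keeps the same frame rate and the same window duration as the original 32 kHz setup:

```python
# Compare the time resolution of the original and adapted mel configs.
for name, sr, win, hop in [
    ("original 32 kHz", 32000, 800, 320),
    ("adapted 16 kHz", 16000, 400, 160),
]:
    frames_per_sec = sr / hop     # spectrogram frames per second of audio
    win_ms = 1000 * win / sr      # analysis window duration in milliseconds
    print(f"{name}: {frames_per_sec:.0f} frames/s, {win_ms:.0f} ms window")
# both configs: 100 frames/s, 25 ms window
```

This is why the pre-trained transformer can be reused: the patch grid it sees has the same time resolution, only the frequency content above 8 kHz is lost.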

The audio files I downloaded were in 32 kHz.

nandacv commented 1 year ago

Thank you for the reply. Can you please confirm if the following code looks good?

from hear21passt.base import get_basic_model, get_model_passt
import torch

# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # extracts mel spectrogram from raw waveforms

# replace the mel layer with the 16 kHz config
from hear21passt.models.preprocess import AugmentMelSTFT
model.mel = AugmentMelSTFT(n_mels=128, sr=16000, win_length=400, hopsize=160, n_fft=512,
                           freqm=48, timem=192, htk=False, fmin=0.0, fmax=None, norm=1,
                           fmin_aug_range=10, fmax_aug_range=1000)

# example inference
model.eval()
with torch.no_grad():
    # audio_wave has the shape [batch, seconds * 16000]; sampling rate is 16 kHz
    # example audio_wave of batch=3 and 10 seconds
    audio = torch.ones((3, 16000 * 10)) * 0.5
    logits = model(audio)
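For real clips instead of the dummy tensor above, waveforms of different lengths need to be padded or truncated to the fixed [batch, seconds * 16000] shape. A minimal sketch (the helper name and fixed 10-second length are my own choices, not from the repo):

```python
import torch

def prepare_batch(waves, sr=16000, seconds=10):
    """Zero-pad or truncate mono waveforms to a fixed length and stack into a batch."""
    target = sr * seconds
    out = torch.zeros(len(waves), target)
    for i, w in enumerate(waves):
        n = min(w.numel(), target)
        out[i, :n] = w[:n]  # shorter clips keep trailing zeros; longer clips are cut
    return out

# two clips of different lengths become one [batch, seconds*sr] tensor
batch = prepare_batch([torch.randn(16000 * 4), torch.randn(16000 * 12)])
print(batch.shape)  # torch.Size([2, 160000])
```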

Also, I assume these logits should be passed through a sigmoid to get the per-class probabilities? Please correct me if I am wrong.

Thanks in advance.

kkoutini commented 1 year ago

Yes, this looks correct.
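To make the sigmoid step concrete, a minimal sketch (assuming AudioSet's 527 classes; AudioSet is multi-label, so an element-wise sigmoid is used rather than a softmax):

```python
import torch

# stand-in logits for a batch of 3 clips over AudioSet's 527 classes
logits = torch.randn(3, 527)

# element-wise sigmoid gives an independent probability per class
probs = torch.sigmoid(logits)

# e.g. take the 5 most probable classes per clip
top_probs, top_indices = probs.topk(5, dim=1)
print(top_indices.shape)  # torch.Size([3, 5])
```

The indices can then be mapped to class names via the AudioSet label CSV.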