kkoutini / PaSST

Efficient Training of Audio Transformers with Patchout
Apache License 2.0

difference of fine-tuning the pretrained models #30

Open nianniandoushabao opened 1 year ago

nianniandoushabao commented 1 year ago

I'm sorry to bother you. I want to ask about the difference between the two ways of getting a pre-trained model; I'm not sure I understand them correctly. The first is in the "Getting a pre-trained model for fine tuning" part. The code is:

from hear21passt.base import get_basic_model,get_model_passt
import torch
# get the PaSST model wrapper, includes Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel) # Extracts mel spectrogram from raw waveforms.

# optionally replace the transformer with one that has the required number of classes, e.g. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476",  n_classes=50)
print(model.net) # the transformer network.

# now model contains mel + the pre-trained transformer, ready to be fine-tuned.
# It still expects input of shape [batch, seconds*32000]; the sampling rate is 32 kHz.

model.train()
model = model.cuda()

The second is in the "Pre-trained models" part.

from models.passt import get_model
# load the full pre-trained transformer directly (527 AudioSet classes),
# with structured patchout over time (40) and over frequency (4)
model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1,
                  fstride=10, tstride=10, input_fdim=128, input_tdim=998,
                  u_patchout=0, s_patchout_t=40, s_patchout_f=4)

I have two questions. Does the first way only fine-tune the transformer layers related to num_classes, leaving the other layers' weights unchanged? And does the second way load the weights of all layers and train all of them again? Or are the two ways the same?

kkoutini commented 1 year ago

Hi! get_model_passt is an alias for the get_model function in passt_hear21 (see here). So in general there should not be a difference between the two functions; they will both return the transformer model. The passt_hear21 repo contains only the model code (including the preprocessing) without the dependencies on sacred and PyTorch Lightning.
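For illustration, a minimal sketch (assuming both passt_hear21 and this repository are installed, and that the default arguments of the two functions match): both entry points should load the same pre-trained weights.

import torch
from hear21passt.base import get_model_passt
from models.passt import get_model

# load the same architecture and number of classes through both entry points
net_a = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=527)
net_b = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527)

# the pre-trained weights should match parameter by parameter
for (name_a, p_a), (name_b, p_b) in zip(net_a.state_dict().items(),
                                        net_b.state_dict().items()):
    assert name_a == name_b and torch.equal(p_a, p_b), name_a
print("both entry points load the same transformer weights")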

If num_classes is different from that of the pre-trained models, then the last MLP layer is randomly initialized (according to the new num_classes) and the model is ready to be fine-tuned on the new task; that is done in timm here, so both ways are the same. Of course, you can also use all the other arguments here.
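To make the fine-tuning behaviour explicit: all pre-trained weights are loaded, only the new classification head starts from random initialization, and by default the optimizer updates every layer, not just the head (unless you freeze parts of the model yourself). A rough sketch of one training step on a 50-class task, where the hyper-parameters and the dummy batch are just placeholders:

import torch
from hear21passt.base import get_basic_model, get_model_passt

model = get_basic_model(mode="logits")
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
model = model.cuda()
model.train()

# every parameter (pre-trained backbone + newly initialized head) gets gradients
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)

waveforms = torch.randn(4, 10 * 32000, device="cuda")  # [batch, seconds*32000] at 32 kHz
targets = torch.randint(0, 50, (4,), device="cuda")    # dummy labels for 50 classes
logits = model(waveforms)                              # [batch, 50]
loss = torch.nn.functional.cross_entropy(logits, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()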

hear21passt.base.get_basic_model will return a wrapper module that contains both the transformer and the mel-spectrogram preprocessing (see here).
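So for inference you can feed raw 32 kHz waveforms directly to the wrapper; it computes the mel spectrogram internally and then runs the transformer. A minimal sketch, following the shapes from the README snippet above:

import torch
from hear21passt.base import get_basic_model

model = get_basic_model(mode="logits")
model.eval()

with torch.no_grad():
    waveform = torch.randn(1, 10 * 32000)  # [batch, seconds*32000], 32 kHz audio
    logits = model(waveform)               # [batch, 527] AudioSet logits
print(logits.shape)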

nianniandoushabao commented 1 year ago

Oh, thank you, thank you very much.